In [None]:
# Enable IPython's autoreload extension in mode 2 so that imported modules are
# reloaded before each cell executes (edits to local .py files are picked up
# without restarting the kernel).
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Root folder that holds one sub-directory per animal class.
path = os.path.join(os.getcwd(), "datasets", "animals10", "raw-img")

# Keep only real class directories. The original unconditionally removed
# '.DS_Store', which raises ValueError on any machine where that file does not
# exist; filtering on isdir() handles that and any other stray file.
# Sorting makes the class -> label index mapping deterministic (os.listdir
# returns entries in arbitrary order).
class_list = sorted(
    entry for entry in os.listdir(path)
    if os.path.isdir(os.path.join(path, entry))
)

# File names per class, in the same order as class_list. Hidden files
# (e.g. a '.DS_Store' inside a class folder) are not images, so skip them.
data_paths = [
    sorted(
        name for name in os.listdir(os.path.join(path, c))
        if not name.startswith(".")
    )
    for c in class_list
]

# Number of samples available for each class.
sample_n = [len(files) for files in data_paths]
print(class_list)

In [None]:
import numpy as np
class Dataset:
    """Mini-batch iterator over the images indexed by the notebook globals.

    The class relies on module-level names defined in an earlier cell:
    ``sample_n`` (number of files per class), ``class_list`` (class directory
    names), ``data_paths`` (file names per class) and ``path`` (dataset root).

    Indexing (``ds[k]``) returns the k-th batch as index arrays only;
    iterating (``for images, labels in ds``) additionally loads the images.

    Args:
        BS: batch size, a positive integer.
        shuffle: when True, samples are shuffled once at construction time
            using numpy's global random state.

    Raises:
        ValueError: if ``BS`` is not positive.
    """

    def __init__(self, BS, shuffle=True):
        if BS <= 0:
            # A non-positive batch size would divide by zero in __len__.
            raise ValueError("BS must be a positive integer")
        self.BS = BS
        # One entry per image, grouped by class: y holds the class index,
        # XIndex the position of the file inside data_paths[class index].
        self.y = np.array(
            [i for i, n in enumerate(sample_n) for _ in range(n)], dtype=int
        )
        self.XIndex = np.array(
            [xi for n in sample_n for xi in range(n)], dtype=int
        )
        self.l = len(self.y)
        if shuffle:
            # Apply the same permutation to both arrays so pairs stay aligned.
            order = np.random.permutation(self.l)
            self.XIndex = self.XIndex[order]
            self.y = self.y[order]

    def __len__(self):
        """Number of batches; the last one may hold fewer than BS samples."""
        return (self.l + self.BS - 1) // self.BS

    def __getitem__(self, key):
        """Return ``(file_indices, class_indices)`` for batch number ``key``.

        Negative keys count from the last batch, as with Python sequences.

        Raises:
            IndexError: if ``key`` does not address an existing batch.
        """
        n_batches = len(self)
        if key < 0:
            key += n_batches
        if not 0 <= key < n_batches:
            raise IndexError("batch index out of range")
        start = key * self.BS
        # Slicing clamps at the end of the array, so the last batch is
        # shortened automatically.
        stop = start + self.BS
        return self.XIndex[start:stop], self.y[start:stop]

    def __iter__(self):
        """Yield ``(images, class_indices)`` for every batch in order."""
        for b in range(len(self)):
            X, y = self[b]
            X_path = np.array([
                os.path.join(path, class_list[ys], data_paths[ys][xs])
                for xs, ys in zip(X, y)
            ])
            yield self.toMem(X_path), y

    def toMem(self, X_path):
        """Load every image in ``X_path`` and return them as a list of PIL images.

        ``Image.open`` is lazy and keeps the file open until the pixel data is
        read. Each image is therefore copied into memory inside a ``with``
        block so the file handle is released immediately instead of leaking
        one descriptor per sample.
        """
        from PIL import Image
        X_images = list()
        for img_path in X_path:
            with Image.open(img_path) as im:
                # copy() forces decoding and returns an image that no longer
                # depends on the (about to be closed) file.
                X_images.append(im.copy())
        return X_images
        
# Example usage, kept as comments: a bare string literal at the end of a cell
# is evaluated and echoed as the cell's output.
# getdata = Dataset(BS=150)
# for x_b, y_b in getdata:
#     print(type(x_b[0]))
#     # print(x_b, y_b, x_b.shape, y_b.shape)

In [None]:
import torch
# torch.manual_seed(17)
from torchvision import datasets, transforms
# Batch size and the fixed size every image is resized to before batching.
BS = 8
IM_HEIGHT = 425
IM_WIDTH = 600

# Resize to a common shape, then convert to a float tensor.
transform = transforms.Compose([
    transforms.Resize([IM_HEIGHT, IM_WIDTH]),
    transforms.ToTensor(),
])
path = os.path.join(os.getcwd(), "datasets", "animals10", "raw-img")
dataset = datasets.ImageFolder(path, transform=transform)

# 70/30 train/validation split. random_split requires the lengths to sum to
# len(dataset) exactly; deriving the validation size from the remainder
# guarantees that for every dataset size (int(n*0.7) + int(n*0.3) + 1 does not).
n_train = int(len(dataset) * 0.7)
n_val = len(dataset) - n_train
train_set, val_set = torch.utils.data.random_split(dataset, [n_train, n_val])

train_set_loader = torch.utils.data.DataLoader(train_set, batch_size=BS, shuffle=True)
# Evaluation does not depend on sample order, so keep it deterministic.
val_set_loader = torch.utils.data.DataLoader(val_set, batch_size=BS, shuffle=False)

In [None]:
import matplotlib.pyplot as plt
# Pull one batch from the training loader and inspect it.
images, labels = next(iter(train_set_loader))
print(images.shape, images[0].shape)

# Show the first image of the batch. Tensors are laid out channel-first,
# while imshow expects channel-last, hence the permute. (The original plotted
# a single colour channel and printed whole tensors to the output.)
sample = images[0]
plt.imshow(sample.permute(1, 2, 0).numpy())
plt.title(f"label index: {labels[0].item()}")
plt.axis("off")
plt.show()

# Value range of the image, as a quick check of the pixel scaling.
print(torch.min(sample))
print(torch.max(sample))

In [None]:
import torch.nn as nn
import torch.nn.functional as F
class someNet(nn.Module):
    """Two-layer fully connected classifier that outputs log-probabilities.

    Args:
        in_features: length of each flattened input sample. Defaults to
            ``IM_HEIGHT * IM_WIDTH * 3`` (read from the notebook globals).
        hidden_features: width of the hidden layer. Defaults to
            ``IM_HEIGHT * IM_WIDTH``.
        n_classes: number of output classes.

    Both weight matrices are initialised from a normal distribution with
    mean 0 and standard deviation 0.1.
    """

    def __init__(self, in_features=None, hidden_features=None, n_classes=10):
        super().__init__()
        # Resolve the defaults lazily so the globals are only needed when the
        # caller does not pass explicit sizes.
        if in_features is None:
            in_features = IM_HEIGHT * IM_WIDTH * 3
        if hidden_features is None:
            hidden_features = IM_HEIGHT * IM_WIDTH
        self.l1 = nn.Linear(in_features, hidden_features)
        self.l2 = nn.Linear(hidden_features, n_classes)
        nn.init.normal_(self.l1.weight, mean=0, std=0.1)
        nn.init.normal_(self.l2.weight, mean=0, std=0.1)

    def forward(self, x):
        """Map flattened inputs ``(..., in_features)`` to log-probabilities ``(..., n_classes)``."""
        x = self.l1(x)
        x = F.relu(x)
        x = self.l2(x)
        # Normalise over the class dimension explicitly; calling log_softmax
        # without `dim` is deprecated and emits a warning.
        x = F.log_softmax(x, dim=-1)
        return x

# Instantiate the network with its default layer sizes.
# NOTE(review): with IM_HEIGHT=425 and IM_WIDTH=600 the first Linear layer
# alone is 765000 x 255000 weights (~1.95e11 values, roughly 780 GB as
# float32) — confirm this is intended / fits in memory.
model = someNet()

In [None]:
# Training configuration.
# WAN: when True, wandb is imported/initialised and losses are logged to it.
WAN = False
# N_EPOCH: number of passes over the training loader made by train().
N_EPOCH = 10
# LR: learning rate handed to the Adam optimizer.
LR = 0.01

In [None]:
# Optional experiment tracking: only import and initialise wandb when enabled.
if WAN:
    import wandb
    # Pass the hyperparameters through `config=` so they are attached to the
    # run. Assigning a plain dict to `wandb.config` merely rebinds the module
    # attribute and the values are not reliably recorded.
    wandb.init(
        project="animals10",
        entity="0xasim",
        config={
            "learning_rate": LR,
            "epochs": N_EPOCH,
            "batch_size": BS,
        },
    )

In [None]:
from tqdm import tqdm
def review(x_b):
    """Flatten every sample of a batch into one row, keeping the batch dimension."""
    return x_b.view(len(x_b), -1)


# Optimisation setup: Adam over all model parameters, cross-entropy objective.
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
loss_fn = nn.CrossEntropyLoss()
def getValLoss():
    """Return the mean loss per sample over ``val_set_loader`` as a float.

    Uses the notebook globals ``model``, ``loss_fn``, ``val_set_loader`` and
    ``review``. Gradients are disabled and the model is put in eval mode for
    the duration of the pass; its previous train/eval mode is restored
    afterwards. Returns ``nan`` when the loader yields no samples.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    n_seen = 0
    try:
        with torch.no_grad():
            for x_v_b, y_v_b in tqdm(val_set_loader):
                # Flatten with the same helper the training loop uses instead
                # of round-tripping every sample through numpy.
                pred = model(review(x_v_b))
                n_batch = y_v_b.shape[0]
                # loss_fn already averages over the batch; weight by the batch
                # size so a shorter final batch is not over-represented.
                total_loss += loss_fn(pred, y_v_b).item() * n_batch
                n_seen += n_batch
    finally:
        model.train(was_training)
    # The original collected tensors in a list and called `.sum()` on the
    # list, which raises AttributeError.
    return total_loss / n_seen if n_seen else float("nan")

def train():
    """Train the global ``model`` for ``N_EPOCH`` passes over ``train_set_loader``.

    Uses the notebook globals ``model``, ``optimizer``, ``loss_fn``,
    ``review`` and ``train_set_loader``. When ``WAN`` is true the loss of each
    batch is logged to wandb as ``train_loss`` and, after every epoch, the
    result of ``getValLoss()`` is logged as ``val_loss``. Returns nothing.
    """
    for _ in range(N_EPOCH):
        # Ensure training mode at the start of each epoch.
        model.train()
        for x_b, y_b in tqdm(train_set_loader):
            # Call the module rather than .forward() so registered hooks run.
            pred = model(review(x_b))
            loss = loss_fn(pred, y_b)
            # clear gradients for this training step
            optimizer.zero_grad()
            # backprop, compute gradients
            loss.backward()
            # apply gradients
            optimizer.step()
            # The original also appended every loss tensor to a list that was
            # never read; that only kept tensors alive, so it is dropped.
            if WAN:
                wandb.log({"train_loss": loss.item()})
        if WAN:
            wandb.log({"val_loss": getValLoss()})
# Run the training loop defined above (N_EPOCH passes over train_set_loader).
train()