In [1]:
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F

from leafy.loader import ImageLoader
from leafy.trainloader import ImageDataset

data_folder = Path("./images")
loader = ImageLoader(data_folder=data_folder)
image_db = ImageDataset(loader)

# class_distribution = loader.get_better_class_distribution()
im, y = image_db[0]
num_classes = len(y)
print(f"{im.shape = }\n{y.shape = }")

Initiated loader on folder /home/joep/Code/Leafliction/images. Found 7233 images.
im.shape = torch.Size([3, 256, 256])
y.shape = torch.Size([8])


In [6]:
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights
from icecream import ic
from torch.nn import Module

class BasicClassifier(Module):
    def __init__(self, num_classes):
        super(BasicClassifier, self).__init__()
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.fc = nn.Linear(2048, 8)
        
    def forward(self, x):
        x = self.resnet(x)
        x = self.fc(x.view(x.size(0), -1))
        return x
        # return F.softmax(x, dim=1) # will break for unbatched data


net = BasicClassifier(num_classes = 8)
net = net.cuda()


preprocess = ResNet50_Weights.DEFAULT.transforms()
# resnet = resnet50(weights=ResNet50_Weights.DEFAULT)


In [7]:
from tqdm.notebook import tqdm

from torch.utils.data import DataLoader
from torch.optim import Adam

trainloader = DataLoader(image_db, 64, shuffle = True)
ce_loss = nn.CrossEntropyLoss()

optim = Adam(net.parameters())
for x, y in tqdm(trainloader):
    optim.zero_grad()
    x = x.cuda()
    x = preprocess(x)
    y = y.cuda()
    y_hat = net(x)
    loss = ce_loss(y_hat, y)
    loss.backward()
    optim.step()
    print(loss)


  0%|          | 0/114 [00:00<?, ?it/s]



tensor(2.0994, device='cuda:0', grad_fn=<DivBackward1>)
tensor(1.7465, device='cuda:0', grad_fn=<DivBackward1>)
tensor(1.3465, device='cuda:0', grad_fn=<DivBackward1>)
tensor(0.8431, device='cuda:0', grad_fn=<DivBackward1>)
tensor(0.5081, device='cuda:0', grad_fn=<DivBackward1>)
tensor(0.4990, device='cuda:0', grad_fn=<DivBackward1>)
tensor(0.2739, device='cuda:0', grad_fn=<DivBackward1>)
tensor(0.2157, device='cuda:0', grad_fn=<DivBackward1>)
tensor(0.1069, device='cuda:0', grad_fn=<DivBackward1>)
tensor(0.0806, device='cuda:0', grad_fn=<DivBackward1>)
tensor(0.2069, device='cuda:0', grad_fn=<DivBackward1>)
tensor(0.1101, device='cuda:0', grad_fn=<DivBackward1>)
tensor(0.1561, device='cuda:0', grad_fn=<DivBackward1>)
tensor(0.1209, device='cuda:0', grad_fn=<DivBackward1>)
tensor(0.3574, device='cuda:0', grad_fn=<DivBackward1>)
tensor(0.0600, device='cuda:0', grad_fn=<DivBackward1>)
tensor(0.0905, device='cuda:0', grad_fn=<DivBackward1>)
tensor(0.4362, device='cuda:0', grad_fn=<DivBack

In [9]:
for x, y in tqdm(trainloader):
    break

net = net.cpu()

y_hat = F.softmax(net(preprocess(x)))
print(y, y_hat)

  0%|          | 0/114 [00:00<?, ?it/s]



tensor([[0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0

  y_hat = F.softmax(net(preprocess(x)))


In [13]:
y_index = torch.argmax(y, dim=1)
y_hat_index = torch.argmax(y_hat, dim=1)

In [15]:
y_index == y_hat_index

tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
        False,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True, False,  True,
         True,  True, False,  True, False,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True])