In [32]:
import torchvision.transforms as tfms

mean = (0.5225, 0.483, 0.3616)
std = (0.2775, 0.2593, 0.2915)

train_tfms = tfms.Compose([
    #already resized to 300x300
    tfms.RandomResizedCrop(256),
    tfms.RandomHorizontalFlip(),
    tfms.ToTensor(),
    tfms.Normalize(mean,std)
])

val_tfms = tfms.Compose([
    tfms.Resize(300),
    tfms.CenterCrop(256),
    tfms.ToTensor(),
    tfms.Normalize(mean,std)
])

In [33]:
from torchvision.datasets import ImageFolder

root = "./data/hymenoptera/"

train_ds = ImageFolder(root+"train", train_tfms)
val_ds = ImageFolder(root+"val", val_tfms)

img_cls = train_ds.classes
print(img_cls)

['ants', 'bees']


In [34]:
from torch.utils.data.dataloader import DataLoader

train_dl = DataLoader(train_ds, 50,shuffle=True,pin_memory=True)
val_dl = DataLoader(val_ds, 20,shuffle=False,pin_memory=True)

In [35]:
# import matplotlib.pyplot as plt
# from mpl_toolkits.axes_grid1 import ImageGrid

# fig = plt.figure(figsize=(10,10))
# grid = ImageGrid(fig,111,(5,5))


# for images,labels in val_dl:
#     image_arr = images.permute(0,2,3,1)

#     for axis,image in zip(grid,image_arr):
#         axis.imshow(image)
#     plt.show()
#     break

In [36]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [37]:
import torch.nn as nn

# class ResNet(nn.Module):
#     def __init__(self):
#         super().__init__()
        
#         self.norm15 = nn.BatchNorm2d(15)
#         self.norm60 = nn.BatchNorm2d(60)
#         self.norm120 = nn.BatchNorm2d(120)
#         self.norm240 = nn.BatchNorm2d(240)
        
        
#         self.conv15 = nn.Conv2d(3,15,3,1,1)
        
#         self.conv60a = nn.Conv2d(15,60,3,1,1)
#         self.conv60b = nn.Conv2d(60,60,3,1,1)
        
#         self.conv120a = nn.Conv2d(60,120,3,1,1)
#         self.conv120b = nn.Conv2d(120,120,3,1,1)
#         self.conv120c = nn.Conv2d(120,120,3,1,1)
        
#         self.conv240a = nn.Conv2d(120,240,3,1,1)
#         self.conv240b = nn.Conv2d(240,240,3,1,1)
#         self.conv240c = nn.Conv2d(240,240,3,1,1)
        
        
#         self.pool = nn.MaxPool2d(2,2)
#         self.aapool = nn.AdaptiveAvgPool2d(1)
        
#         self.flat = nn.Flatten()
        
#         self.linear = nn.Linear(240,2)
        
    
#     def forward(self,data):
#         out = torch.relu(self.norm15(self.conv15(data)))# 15 256 256
#         out = self.pool(out)#15 128 128
        
#         out = torch.relu(self.norm60(self.conv60a(out)))# 60 128 128
#         x = out
#         out = torch.relu(self.conv60b(out)+x)
        
#         out = self.pool(out)#60 64 64
        
#         x = self.norm120(self.conv240a(out))#120 64 64
#         out = torch.relu(x)
#         out = torch.relu(self.conv120b(out)+x)
#         out = torch.relu(self.conv120c(out))
        
#         out = self.pool(out)#120 32 32
        
#         x = self.norm240(self.conv240a(out))# 240 32 32
#         out = torch.relu(x)
#         out = torch.relu(self.conv240b(out)+x)
#         out = torch.relu(self.conv240c(out))
        
#         out = self.aapool(out)#240 1 1 
        
#         out = self.flat(out)#240
#         out = self.linear(out)#2
        
#         out = torch.softmax(out, dim=-1)
            
#         return out
        
        

In [38]:
from torchvision import models

model = models.resnet18(pretrained=1)

num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs,2)

model = model.to(device)

In [39]:
loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(),lr=0.001)
print(len(val_dl))

4


In [40]:
import numpy as np

def fit(epochs):
    for epoch in range(epochs):
        print("Epoch: ",epoch+1)

        for images,labels in train_dl:
            preds = model(images.to(device))
                
            loss = loss_fn(preds.to(device), labels.to(device))
            loss.backward()
                
            opt.step()
            opt.zero_grad()
            break

        print("Loss: ",round(loss.item(),6))
        
        
        with torch.no_grad():
            correct = 0
            for images,labels in val_dl:
                preds = model(images.to(device))
                
                for i in range(len(preds)):
                    if (preds[i].max()==preds[i][labels[i].item()]):
                        correct += 1
                
            acc = correct/len(val_ds)
            print("Accuracy: ",round(acc*100,2))
            print("")
        if acc>0.80:
            break


In [43]:
fit(20)

Epoch:  1
Loss:  0.244928
Accuracy:  0.5875

Epoch:  2
Loss:  0.170394
Accuracy:  0.6

Epoch:  3
Loss:  0.416226
Accuracy:  0.575

Epoch:  4
Loss:  0.15215
Accuracy:  0.5625

Epoch:  5
Loss:  0.214524
Accuracy:  0.55

Epoch:  6
Loss:  0.214051
Accuracy:  0.525

Epoch:  7
Loss:  0.137485
Accuracy:  0.55

Epoch:  8
Loss:  0.181773
Accuracy:  0.5625

Epoch:  9
Loss:  0.289508
Accuracy:  0.5625

Epoch:  10
Loss:  0.18628
Accuracy:  0.575

Epoch:  11
Loss:  0.136917
Accuracy:  0.575

Epoch:  12
Loss:  0.445
Accuracy:  0.5625

Epoch:  13
Loss:  0.221793
Accuracy:  0.55

Epoch:  14
Loss:  0.148318
Accuracy:  0.5125

Epoch:  15
Loss:  0.269762
Accuracy:  0.5

Epoch:  16
Loss:  0.447253
Accuracy:  0.5125

Epoch:  17
Loss:  0.286475
Accuracy:  0.5375

Epoch:  18
