In [1]:
import torchvision.transforms as tfms

mean = (0.5225, 0.483, 0.3616)
std = (0.2775, 0.2593, 0.2915)

mean = (0.5,0.5,0.5)
std = (0.25,0.25,0.25)
train_tfms = tfms.Compose([
    #already resized to 300x300
    tfms.RandomResizedCrop(256),
    tfms.RandomHorizontalFlip(),
    tfms.ToTensor(),
    tfms.Normalize(mean,std)
])

val_tfms = tfms.Compose([
    tfms.Resize(300),
    tfms.CenterCrop(256),
    tfms.ToTensor(),
    tfms.Normalize(mean,std)
])

In [2]:
from torchvision.datasets import ImageFolder

root = "./data/hymenoptera/"

train_ds = ImageFolder(root+"train", train_tfms)
val_ds = ImageFolder(root+"val", val_tfms)

img_cls = train_ds.classes
print(img_cls)

['ants', 'bees']


In [3]:
from torch.utils.data.dataloader import DataLoader

train_dl = DataLoader(train_ds, 5,shuffle=True,num_workers=4,pin_memory=True)
val_dl = DataLoader(val_ds, 20,shuffle=False,num_workers=4,pin_memory=True)

In [4]:
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [5]:
def fit(epochs,model,loss_fn,opt):
    for epoch in range(epochs):
        print("Epoch: ",epoch+1)

        for images,labels in train_dl:
            preds = model(images.to(device))
        
            loss = loss_fn(preds.to(device), labels.to(device))
            loss.backward()
            
            opt.step()
            opt.zero_grad()

        print("Loss: ",round(loss.item(),6))
        
        
        with torch.no_grad():
            correct = 0
            for images,labels in val_dl:
                preds = model(images.to(device))
                
                for i in range(len(preds)):
                    if (preds[i].max()==preds[i][labels[i].item()]):
                        correct += 1
                
            acc = correct/len(val_ds)
            print("Accuracy: ",round(acc*100,2))
            print("")
        if acc>0.80:
            break


In [6]:
from torchvision import models

HymenopteraModel = models.resnet18(pretrained=1)
print(HymenopteraModel.fc)

for param in HymenopteraModel.parameters():
    param.requires_grad = False

num_ftrs = HymenopteraModel.fc.in_features
HymenopteraModel.fc = nn.Linear(num_ftrs,2)
print(HymenopteraModel.fc)

Linear(in_features=512, out_features=1000, bias=True)
Linear(in_features=512, out_features=2, bias=True)


In [7]:
HymenopteraModel = HymenopteraModel.to(device)
CrossEntropy_fn = nn.CrossEntropyLoss()
AdamOpt = torch.optim.Adam(HymenopteraModel.parameters(),lr=0.001)

In [8]:
fit(5,HymenopteraModel,CrossEntropy_fn,AdamOpt)

Epoch:  1


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Loss:  0.368291
Accuracy:  53.75

Epoch:  2
Loss:  0.313351
Accuracy:  48.75

Epoch:  3
Loss:  0.224606
Accuracy:  46.25

Epoch:  4
Loss:  0.486599
Accuracy:  46.25

Epoch:  5
Loss:  0.607188
Accuracy:  50.0

