In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms.functional as f
from torchvision import datasets, transforms
from tqdm import tqdm
from torch.utils.data import DataLoader, random_split

import matplotlib.pyplot as plt
import numpy as np


In [2]:
vgg_models = {
    'A': [64, 'Max', 128, 'Max', 256,256, 'Max', 512,512, 'Max', 512,512, 'Max'],

    'B': [64,64, 'Max', 128,128, 'Max', 256, 256, 'Max', 512,512, 'Max', 512,512, 'Max'],

    'D': [64,64, 'Max', 128,128, 'Max', 256,256,256, 'Max', 512,512,512, 'Max', 512,512,512, 'Max'],

    'E': [64,64, 'Max', 128,128, 'Max', 256,256,256,256, 'Max', 512,512,512,512, 'Max', 512,512,512,512, 'Max']
}

# Model

In [3]:
class VGG(nn.Module):
    def __init__(self, in_channels=3, num_classes=10, architecture=None, init_weights = True):
      super(VGG, self).__init__()

      self.in_channels = in_channels

      self.convnet = self.make_conv_architecture(architecture)
      self.fc_layers = nn.Sequential(
          nn.Linear(512*7*7, 4096),
          nn.ReLU(),
          nn.Dropout(p=0.5),
          nn.Linear(4096, 4096),
          nn.ReLU(),
          nn.Dropout(p=0.5),
          nn.Linear(4096, num_classes)
      )
      if init_weights:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
      x = self.convnet(x)
      x = x.view(-1, 512*7*7)
      x = self.fc_layers(x)
      return x

    def make_conv_architecture(self, model_architecture):
      layers = []
      in_channels = self.in_channels
      for x in model_architecture:
        if type(x) == int:
          out_channels = x
          layers += [nn.Conv2d(in_channels, out_channels, 3, 1, 1)]
          layers += [nn.ReLU()]
          in_channels = out_channels
        if type(x) == str:
          layers += [nn.MaxPool2d((2,2), 2)]
      return nn.Sequential(*layers)




In [19]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EPOCH = 70
PATH = "vgg_checkpoint.pth"
BATCH_SIZE = 64
LR = 0.0001

In [None]:
!unzip /content/drive/MyDrive/data/Mammals_Images.zip -d /content

In [6]:
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
])

dataset = datasets.ImageFolder(root='/content/mammals', transform = transform)


In [7]:
train_size = int(0.96*len(dataset))
test_size = int(len(dataset)-train_size)
train_size,test_size

(13200, 551)

In [8]:
train_dataset, test_dataset = random_split(dataset , [train_size, test_size ])

In [9]:
train_loader = DataLoader(train_dataset , batch_size= BATCH_SIZE,shuffle = True, num_workers = 2, pin_memory = True, drop_last = True)
test_loader = DataLoader(test_dataset , batch_size= 32,shuffle = True)

In [10]:
classes = train_dataset.dataset.classes
num_classes=len(classes)
print(classes)
print("\n")
print(num_classes)

['african_elephant', 'alpaca', 'american_bison', 'anteater', 'arctic_fox', 'armadillo', 'baboon', 'badger', 'blue_whale', 'brown_bear', 'camel', 'dolphin', 'giraffe', 'groundhog', 'highland_cattle', 'horse', 'jackal', 'kangaroo', 'koala', 'manatee', 'mongoose', 'mountain_goat', 'opossum', 'orangutan', 'otter', 'polar_bear', 'porcupine', 'red_panda', 'rhinoceros', 'sea_lion', 'seal', 'snow_leopard', 'squirrel', 'sugar_glider', 'tapir', 'vampire_bat', 'vicuna', 'walrus', 'warthog', 'water_buffalo', 'weasel', 'wildebeest', 'wombat', 'yak', 'zebra']


45


In [20]:
model = VGG(num_classes = 45,architecture= vgg_models['D'] ).to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr= LR)
criterion = nn.CrossEntropyLoss()

In [14]:
from torchsummary import summary
summary(model, (3,224,224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReLU-14          [-1, 256,

In [None]:
best_accuracy = 0.0
for epoch in range(EPOCH):
  epoch_loss = 0.0
  correct_prediction = 0
  total_samples = 0
  for data,targets in tqdm(train_loader):
      data = data.to(DEVICE)
      targets = targets.to(DEVICE)

      optimizer.zero_grad()
      scores = model(data)
      loss = criterion(scores, targets)

      loss.backward()
      optimizer.step()

      epoch_loss += loss.item()
      _,pred = torch.max(scores, 1)
      correct_prediction += torch.sum(pred == targets).item()
      total_samples += targets.size(0)
  epoch_accuracy = correct_prediction / total_samples
  epoch_loss /= len(train_loader)

  print('Epoch: {} \tLoss: {:.4f} \tAcc: {:.4f}'.format(epoch + 1, epoch_loss, epoch_accuracy))
  if epoch_accuracy > best_accuracy:
        best_accuracy = epoch_accuracy
        torch.save({'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_accuracy': best_accuracy}, PATH)


100%|██████████| 206/206 [02:59<00:00,  1.15it/s]


Epoch: 1 	Loss: 3.6721 	Acc: 0.0425


100%|██████████| 206/206 [02:59<00:00,  1.15it/s]


Epoch: 2 	Loss: 3.4425 	Acc: 0.0835


100%|██████████| 206/206 [02:59<00:00,  1.14it/s]


Epoch: 3 	Loss: 3.2553 	Acc: 0.1279


100%|██████████| 206/206 [02:59<00:00,  1.15it/s]


Epoch: 4 	Loss: 3.0752 	Acc: 0.1715


100%|██████████| 206/206 [03:00<00:00,  1.14it/s]


Epoch: 5 	Loss: 2.8303 	Acc: 0.2330


100%|██████████| 206/206 [03:00<00:00,  1.14it/s]


Epoch: 6 	Loss: 2.6036 	Acc: 0.2868


100%|██████████| 206/206 [03:00<00:00,  1.14it/s]


Epoch: 7 	Loss: 2.3167 	Acc: 0.3528


100%|██████████| 206/206 [03:00<00:00,  1.14it/s]


Epoch: 8 	Loss: 1.9874 	Acc: 0.4430


100%|██████████| 206/206 [03:00<00:00,  1.14it/s]


Epoch: 9 	Loss: 1.5113 	Acc: 0.5670


100%|██████████| 206/206 [03:00<00:00,  1.14it/s]


Epoch: 10 	Loss: 0.9452 	Acc: 0.7175


100%|██████████| 206/206 [03:01<00:00,  1.14it/s]


Epoch: 11 	Loss: 0.5128 	Acc: 0.8461


100%|██████████| 206/206 [03:01<00:00,  1.14it/s]


Epoch: 12 	Loss: 0.2920 	Acc: 0.9140


100%|██████████| 206/206 [03:00<00:00,  1.14it/s]


Epoch: 13 	Loss: 0.1988 	Acc: 0.9427


 78%|███████▊  | 161/206 [02:21<00:39,  1.15it/s]