In [None]:
## Importing Libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm
import torchvision

In [None]:
## Set Device
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
import sys
class Identity(nn.Module):
  """Pass-through layer: forward() hands back its input unchanged.

  The layer owns no parameters or buffers, so the constructor inherited
  from nn.Module is all the initialisation it needs.
  """

  def forward(self, x):
    # No computation: the very same object that came in is returned.
    return x

# Load a pretrained VGG16 and adapt it to 10 output classes
model = torchvision.models.vgg16(pretrained=True)

# Freeze every pretrained weight (feature extraction). The layers created
# below are built afterwards, so they stay trainable by default.
for param in model.parameters():
  param.requires_grad_(False)

# Bypass the average-pooling stage and swap in a small two-layer head.
# NOTE(review): Linear(512, ...) assumes the flattened feature map has 512
# values per image (true for 32x32 inputs) - confirm if the input size changes.
model.avgpool = Identity()
model.classifier = nn.Sequential(
    nn.Linear(512, 100),
    nn.ReLU(),
    nn.Linear(100, 10),
)

# Module.to() moves the parameters in place and returns the module itself
model = model.to(device)
print(model)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [None]:
## Hyperparameters
in_channels = 3        # not referenced by the other cells shown in this notebook
num_classes = 10       # not referenced either; the classifier head hard-codes 10
learning_rate = 0.001  # step size handed to the optimizer
batch_size = 1024      # samples per batch for both data loaders
num_epochs = 5         # number of full passes over the training loader

In [None]:
## Load Data
# The torchvision pretrained weights were trained on images normalized with
# the ImageNet per-channel mean/std, so the same normalization is applied
# after converting each image to a tensor. One transform is shared by both
# splits so training and evaluation see identically preprocessed inputs.
cifar_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = datasets.CIFAR10(root='dataset/', train=True, transform=cifar_transform, download=True)
# Shuffle the training batches every epoch
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = datasets.CIFAR10(root='dataset/', train=False, transform=cifar_transform, download=True)
# Accuracy does not depend on sample order, so the test set is left unshuffled
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
## Loss and Optimizer
criterion = nn.CrossEntropyLoss()

# Give Adam only the parameters that are still trainable. The pretrained
# backbone was frozen (requires_grad=False), so passing it to the optimizer
# would just register tensors that can never be updated.
# NOTE: if layers are unfrozen later, rebuild the optimizer so it sees them.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=learning_rate)

In [None]:
## Train Network
for epoch in range(num_epochs):
  for batch_idx, (data, targets) in enumerate(tqdm(train_loader)):
    # Move the batch onto the device the model was moved to
    data, targets = data.to(device), targets.to(device)

    # Drop the gradients accumulated by the previous step
    optimizer.zero_grad()

    # Forward pass and loss
    scores = model(data)
    loss = criterion(scores, targets)

    # Backward pass, then a single optimizer update
    loss.backward()
    optimizer.step()

100%|██████████| 49/49 [00:05<00:00,  8.52it/s]
100%|██████████| 49/49 [00:05<00:00,  8.61it/s]
100%|██████████| 49/49 [00:05<00:00,  8.54it/s]
100%|██████████| 49/49 [00:05<00:00,  8.40it/s]
100%|██████████| 49/49 [00:05<00:00,  8.35it/s]


In [None]:
## Check accuracy on the training & test sets to see how good our model is
def check_accuracy(loader, model, device=None):
  """Return the fraction of samples in `loader` that `model` classifies correctly.

  Args:
    loader: iterable yielding (inputs, labels) batches of tensors.
    model: nn.Module mapping a batch of inputs to scores of shape
      (batch, num_classes); the predicted class is the argmax over dim 1.
    device: device the batches are moved to before the forward pass. When
      None, the device of the model's first parameter is used (CPU if the
      model has no parameters).

  Returns:
    Accuracy as a Python float in [0, 1]; 0.0 when the loader yields no samples.
  """
  if device is None:
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

  num_correct = 0
  num_samples = 0

  # Remember the current mode so that evaluating a model which is already in
  # eval mode does not silently switch it back to train mode afterwards.
  was_training = model.training
  model.eval()

  try:
    with torch.no_grad():
      for x, y in loader:
        x = x.to(device=device)
        y = y.to(device=device)

        # Forward
        scores = model(x)
        # Index of the highest score along the class dimension
        predictions = scores.argmax(dim=1)

        num_correct += (predictions == y).sum()
        num_samples += predictions.size(0)  # size of the batch dimension
  finally:
    # Restore the original mode even if the forward pass raised
    model.train(was_training)

  # An empty loader would otherwise divide by zero
  if num_samples == 0:
    return 0.0
  # float() turns the accumulated 0-dim tensor into a plain number, so the
  # result is a true division regardless of the torch version.
  return float(num_correct) / num_samples

# Evaluate on the training split first, then on the test split
train_accuracy = check_accuracy(train_loader, model)
print(f'Accuracy on training data: {train_accuracy*100:.2f}')

test_accuracy = check_accuracy(test_loader, model)
print(f"Accuracy on test set: {test_accuracy*100:.2f}")

Accuracy on training data: 62.60
Accuracy on test set: 60.89
