In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as f
from torchvision import models
from torch.nn.modules.activation import ReLU
from torchvision.transforms.transforms import Normalize

# Run on the first CUDA GPU when one is present; otherwise fall back to the
# CPU so the notebook still executes on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Deterministic preprocessing: resize to 256, keep the central 224x224 crop,
# convert to a tensor and normalise each channel.
# (The mean/std values are presumably the ImageNet statistics the pretrained
# backbone expects -- confirm.)
transform = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Augmenting variant: random resized crop, small random rotation and random
# horizontal flip ahead of the same crop / tensor / normalise steps.
# NOTE(review): nothing in this cell uses `train_transform`; both datasets
# below are built with `transform`.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(256),
    transforms.RandomRotation(15),
    transforms.RandomHorizontalFlip(),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

batch_size = 64

# CIFAR-10 training split, reshuffled by the loader on every pass.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset, batch_size=batch_size, shuffle=True, num_workers=2
)

# CIFAR-10 test split, iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
    dataset=testset, batch_size=batch_size, shuffle=False, num_workers=2
)

# Human-readable class names, indexed by label id.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Load a pretrained ResNet-50 and move it to the selected device.
resnet = models.resnet50(pretrained = True)
resnet = resnet.to(device)
# Freeze every pretrained weight. The autograd flag is `requires_grad`;
# assigning to any other attribute name (e.g. `required_grad`) silently
# creates a new Python attribute and leaves the parameter trainable.
for param in resnet.parameters():
  param.requires_grad = False
# Show the stock classification layer (last expression -> cell output).
resnet.fc

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "


Linear(in_features=2048, out_features=1000, bias=True)

In [None]:
# Replace the stock classifier with a small 10-class head.
fc_resnet_in_features = resnet.fc.in_features

# The head ends in a plain Linear layer and therefore emits raw logits.
# nn.CrossEntropyLoss applies log-softmax internally, so placing an explicit
# Softmax here would normalise twice and flatten the gradients. Use
# torch.softmax(outputs, dim=1) at inference time if probabilities are needed;
# argmax-based accuracy is unaffected either way.
resnet.fc = nn.Sequential(
    nn.Linear(fc_resnet_in_features, 256),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(256, 10),
)

# The freshly created layers live on the CPU; move the whole model again.
resnet = resnet.to(device)

In [None]:
import torch.optim as optim

# Classification loss, Adam over every parameter of the network, and a
# step schedule that multiplies the learning rate by 0.1 at milestones 10
# and 20.
# NOTE(review): the schedule only takes effect if the training loop calls
# `scheduler.step()`.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=resnet.parameters(), lr=5e-3)
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer=optimizer,
    milestones=[10, 20],
    gamma=0.1,
)

In [None]:
# Number of full passes over the training loader performed by the training loop.
epochs = 100

In [None]:
# Display the loader's repr as a quick check that it was constructed.
trainloader

<torch.utils.data.dataloader.DataLoader at 0x7fc2a9c4c9d0>

In [None]:
# Placeholders that start out empty; `random_train` is overwritten with a
# per-batch permutation inside the training loop.
val_input = random_train = None
# Fraction of every training batch used for the gradient step; the remainder
# of the batch is held back for validation.
split = 0.8

In [None]:
def accuracy_metrics(outputs, labels):
  """Return the fraction of rows whose highest-scoring class equals the label.

  Args:
    outputs: 2-D tensor of per-class scores, one row per sample.
    labels: 1-D tensor of integer class ids, one per row of ``outputs``.

  Returns:
    A Python float in [0, 1]: number of correct predictions divided by the
    number of rows in ``outputs``.
  """
  predicted = outputs.argmax(dim=1)
  n_correct = (predicted == labels).sum().item()
  return n_correct / outputs.shape[0]

In [None]:
# Per-epoch history buffers filled by the training loop and read by the plots.
save_train_loss, save_val_loss = [], []
save_train_acc, save_val_acc = [], []

In [None]:
# Number of batches the training loader yields per epoch.
len(trainloader)

782

In [None]:
# Fine-tuning loop.
#
# Every batch from `trainloader` is cut in two: the first `split` fraction
# drives a gradient step, the remainder is scored without gradients as a
# validation signal. Per-epoch mean loss and accuracy of both parts are
# appended to the `save_*` history lists and printed.
#
# NOTE(review): the validation slice is carved out of shuffled training
# batches, so a sample held back in one epoch can be trained on in another.
# Treat the validation numbers as a rough monitor, not a true hold-out score.
for epoch in range(epochs):
  if epoch == 10:
    # Unfreeze the whole network. The autograd flag is `requires_grad`;
    # writing to a misspelled attribute would be a silent no-op.
    for param in resnet.parameters():
      param.requires_grad = True
  running_loss = 0.0
  running_loss_val = 0.0
  accuracy = 0.
  accuracy_val = 0.
  count = 0  # batches processed in this epoch
  for i, data in enumerate(trainloader, 0):
    inputs, labels = data
    split_i = int(split * inputs.shape[0])
    # Slice the validation part first, while `inputs`/`labels` still hold
    # the full batch.
    inputs_val = inputs[split_i:].to(device)
    labels_val = labels[split_i:].to(device)
    inputs = inputs[:split_i].to(device)
    labels = labels[:split_i].to(device)
    # Re-order the training slice (kept from the original notebook).
    random_train = np.random.permutation(inputs.shape[0])
    inputs = inputs[random_train]
    labels = labels[random_train]

    # --- training step: dropout active, BatchNorm uses batch statistics ---
    resnet.train()
    optimizer.zero_grad()
    outputs = resnet(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # --- validation: eval mode so dropout is disabled and BatchNorm running
    # statistics are not updated by held-out data ---
    resnet.eval()
    with torch.no_grad():
      val_outputs = resnet(inputs_val)
      loss_val = criterion(val_outputs, labels_val)

    running_loss += loss.item()
    running_loss_val += loss_val.item()
    accuracy += accuracy_metrics(outputs, labels)
    accuracy_val += accuracy_metrics(val_outputs, labels_val)
    # Exactly one increment per batch, so `accuracy / count` is a true mean.
    count += 1

  # Advance the LR schedule once per epoch, after that epoch's optimizer steps.
  scheduler.step()

  save_train_loss.append(running_loss / len(trainloader))
  save_val_loss.append(running_loss_val / len(trainloader))
  save_train_acc.append(accuracy / count)
  save_val_acc.append(accuracy_val / count)
  print(f'Epoch {epoch + 1}  -  train loss: {running_loss / len(trainloader):.3f} -- train accuracy: {accuracy / count : .3f} --  val loss: {running_loss_val / len(trainloader):.3f} -- val accuracy: {accuracy_val / count  : .3f}')

# Printed once, after the final epoch.
print('Finished Training')

In [None]:
# Draw one figure per metric (accuracy first, then loss), each overlaying the
# per-epoch training and validation histories.
for figure_title, y_label, train_history, val_history in (
    ('model accuracy', 'accuracy', save_train_acc, save_val_acc),
    ('model loss', 'loss', save_train_loss, save_val_loss),
):
  plt.plot(train_history)
  plt.plot(val_history)
  plt.title(figure_title)
  plt.ylabel(y_label)
  plt.xlabel('epoch')
  plt.legend(['train', 'validation'], loc='upper left')
  plt.show()

In [None]:
# Evaluate the trained network on the test loader and print the mean
# per-batch loss and accuracy.
#
# Eval mode disables dropout and makes BatchNorm use its running statistics,
# so the test data neither perturbs the predictions nor alters the model.
resnet.eval()
running_loss = 0
accuracy = 0
count = 0
with torch.no_grad():
  for i, data in enumerate(testloader, 0):
    inputs, labels = data
    inputs, labels = inputs.to(device), labels.to(device)
    outputs = resnet(inputs)
    loss = criterion(outputs, labels)
    running_loss += loss.item()
    accuracy += accuracy_metrics(outputs, labels)
    count += 1
# `epoch` is the value left behind by the training loop; `i` is the index of
# the last test batch. The loss is averaged over the number of *test* batches.
print(f'[{epoch + 1}, {i + 1:5d}] test loss: {running_loss / len(testloader):.3f}')
print(f'[{epoch + 1}, {i + 1:5d}] test accuracy: {accuracy / count : .3f}')