<a href="https://colab.research.google.com/github/ajuhz/Artificial-Intelligence/blob/master/Transfer_Learning_with_Pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Note: This is a simplified and more explainable version from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torch.autograd import Variable
from torchvision import datasets, models, transforms
import os
import numpy as np

In [5]:
#Data Augmentation and normalization for training

# Data augmentation and normalization for training
# Just normalization for validation
#class torchvision.transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333), interpolation=2)
#transforms.RandomResizedCrop(224) --> A crop of random size (default: of 0.08 to 1.0) of the original size and a 
#random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. 
#This crop is finally resized to given size (224 in this case). 
#transforms.CenterCrop(224)--> Crops the image at the center. 224 is the Desired output size of the crop.
#class torchvision.transforms.Normalize(mean, std)
#Normalize a tensor image with mean and standard deviation. Given mean: (M1,...,Mn) and std: (S1,..,Sn) for n channels, 
#this transform will normalize each channel of the input torch.Tensor i.e. 
#input[channel] = (input[channel] - mean[channel]) / std[channel]
#Parameters:     mean (sequence) – Sequence of means for each channel.
#                std (sequence) – Sequence of standard deviations for each channel.
data_transforms={'train':transforms.Compose([transforms.RandomResizedCrop(224),
                                             transforms.RandomHorizontalFlip(),
                                             transforms.ToTensor(),
                                             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
                 'val' : transforms.Compose([transforms.Resize(256),
                                            transforms.CenterCrop(224),
                                            transforms.ToTensor(),
                                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),}

In [6]:
!pwd

/content


In [10]:
%cd /content/drive/My Drive/Python

/content/drive/My Drive/Python


In [11]:
data_dir = 'hymenoptera_data'

In [12]:
#Create a dictionary that contains the information of the images in both the training and validation set
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x]) for x in ['train', 'val']}
#Create a dictionary that contians the data loader
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], 
                                              batch_size=4,
                                              shuffle=True) for x in ['train', 'val']}

In [20]:
print(image_datasets['train'].classes)
print(image_datasets['train'].imgs)


['ants', 'bees']
[('hymenoptera_data/train/ants/0013035.jpg', 0), ('hymenoptera_data/train/ants/1030023514_aad5c608f9.jpg', 0), ('hymenoptera_data/train/ants/1095476100_3906d8afde.jpg', 0), ('hymenoptera_data/train/ants/1099452230_d1949d3250.jpg', 0), ('hymenoptera_data/train/ants/116570827_e9c126745d.jpg', 0), ('hymenoptera_data/train/ants/1225872729_6f0856588f.jpg', 0), ('hymenoptera_data/train/ants/1262877379_64fcada201.jpg', 0), ('hymenoptera_data/train/ants/1269756697_0bce92cdab.jpg', 0), ('hymenoptera_data/train/ants/1286984635_5119e80de1.jpg', 0), ('hymenoptera_data/train/ants/132478121_2a430adea2.jpg', 0), ('hymenoptera_data/train/ants/1360291657_dc248c5eea.jpg', 0), ('hymenoptera_data/train/ants/1368913450_e146e2fb6d.jpg', 0), ('hymenoptera_data/train/ants/1473187633_63ccaacea6.jpg', 0), ('hymenoptera_data/train/ants/148715752_302c84f5a4.jpg', 0), ('hymenoptera_data/train/ants/1489674356_09d48dde0a.jpg', 0), ('hymenoptera_data/train/ants/149244013_c529578289.jpg', 0), ('hymeno

In [35]:
next(iter(dataloaders['train']))

[tensor([[[[-0.0458, -0.0458, -0.0458,  ...,  1.0844,  1.0502,  1.0673],
           [-0.0116, -0.0116, -0.0287,  ...,  1.0844,  1.0502,  1.0331],
           [ 0.0056,  0.0056, -0.0116,  ...,  1.0844,  1.0502,  0.9817],
           ...,
           [-0.1999, -0.1314, -0.1828,  ...,  0.8447,  0.8276,  0.8789],
           [-0.1999, -0.1828, -0.1828,  ...,  0.7762,  0.7591,  0.8104],
           [-0.2171, -0.2171, -0.1828,  ...,  0.7419,  0.7077,  0.7248]],
 
          [[ 0.1001,  0.1001,  0.1001,  ...,  1.2906,  1.2556,  1.2556],
           [ 0.1176,  0.1176,  0.1001,  ...,  1.2906,  1.2556,  1.2206],
           [ 0.1176,  0.1176,  0.1001,  ...,  1.2906,  1.2556,  1.1681],
           ...,
           [-0.2500, -0.1800, -0.2325,  ...,  0.9230,  0.9055,  1.0280],
           [-0.2500, -0.2325, -0.2325,  ...,  0.8704,  0.8354,  0.9230],
           [-0.2675, -0.2675, -0.2325,  ...,  0.8354,  0.8004,  0.8179]],
 
          [[ 0.0082,  0.0082,  0.0082,  ...,  1.1759,  1.1411,  1.1062],
           [ 

In [36]:
#load trained ResNet
model_conv=torchvision.models.resnet18(pretrained=True)

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


HBox(children=(FloatProgress(value=0.0, max=46827520.0), HTML(value='')))




In [37]:
#freezing all layers of the model
for params in model_conv.parameters():
  params.requires_grad=False

In [38]:
model_conv

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  