In [None]:
import copy
import glob
import os
import random
import shutil
import tarfile
import time
from random import sample

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_url
from torchvision.transforms import Compose
from torchvision.utils import make_grid
from tqdm import tqdm

%matplotlib inline





# Fetch the fast.ai PNG mirror of CIFAR-10 and lay it out under DATA_ROOT, which is
# the location every later cell reads from (/content/data/cifar10/...).
CIFAR_URL = "https://s3.amazonaws.com/fast-ai-imageclas/cifar10.tgz"
DATA_ROOT = "/content/data"
CIFAR_DIR = os.path.join(DATA_ROOT, "cifar10")
CIFAR10_CLASSES = ("airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck")

download_url(CIFAR_URL, ".")

# Extract only once. Re-extracting after the validation split (next cell) would
# re-create the moved images inside train/, leaking validation images into the
# training set.
if not os.path.isdir(os.path.join(CIFAR_DIR, "train")):
    with tarfile.open("./cifar10.tgz", "r:gz") as tar:
        # The "data" extraction filter rejects absolute paths, ".." members and
        # other unsafe entries in the downloaded archive. It exists only on Python
        # versions that ship tarfile.data_filter, so fall back to the plain call.
        extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
        tar.extractall(path=DATA_ROOT, **extract_kwargs)

# One (initially empty) folder per class for the held-out validation split, in the
# <root>/<class>/ layout that ImageFolder expects. exist_ok makes re-runs harmless.
for class_name in CIFAR10_CLASSES:
    os.makedirs(os.path.join(CIFAR_DIR, "validate", class_name), exist_ok=True)

Downloading https://s3.amazonaws.com/fast-ai-imageclas/cifar10.tgz to ./cifar10.tgz


  0%|          | 0/135107811 [00:00<?, ?it/s]

In [None]:
# Hold out a validation split: move VAL_PER_CLASS randomly chosen training images of
# each class from train/<class>/ into validate/<class>/.
# Re-running the cell is safe: a class whose validate/ folder already contains files
# is skipped, so no further images are taken out of train/. The candidate list is
# sorted and the RNG seeded, so a fresh run always selects the same images.
VAL_PER_CLASS = 500
SPLIT_SEED = 42
cifar_root = "/content/data/cifar10"
split_rng = random.Random(SPLIT_SEED)

for class_name in ("airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"):
    train_dir = os.path.join(cifar_root, "train", class_name)
    val_dir = os.path.join(cifar_root, "validate", class_name)
    os.makedirs(val_dir, exist_ok=True)
    if os.listdir(val_dir):
        continue  # this class was already split on an earlier run
    # glob order is filesystem-dependent; sort so the seeded sample is reproducible.
    candidates = sorted(glob.glob(os.path.join(train_dir, "*.png")))
    for image_path in split_rng.sample(candidates, VAL_PER_CLASS):
        shutil.move(image_path, val_dir)


In [None]:
# Image transformations.
# Both pipelines resize/crop images to 224x224 and normalise with the ImageNet
# channel mean/std, matching the ImageNet-pretrained ResNet loaded below.

# Training pipeline: random crop, rotation and flip for data augmentation.
training_augment = transforms.Compose([
    transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
    transforms.RandomRotation(degrees=15),
    # NOTE(review): ColorJitter() with all-default arguments leaves images
    # unchanged; pass brightness/contrast/saturation/hue to actually jitter colour.
    transforms.ColorJitter(),
    transforms.RandomHorizontalFlip(),
    transforms.CenterCrop(size=224),  # ImageNet standard input size
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),  # ImageNet mean / std
])

# Validation pipeline: deterministic resize + centre crop, no augmentation.
val_augment = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

train_ds = ImageFolder("/content/data/cifar10/train", training_augment)
val_ds = ImageFolder("/content/data/cifar10/validate", val_augment)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# PNG decoding and the augmentations above run on the CPU. In the main process they
# are presumably the bottleneck of the logged run (~2.3 it/s). Worker processes
# prepare batches in parallel, and pinned host memory makes host->GPU copies cheaper.
NUM_WORKERS = min(4, os.cpu_count() or 1)
PIN_MEMORY = device.type == "cuda"

train_dl = DataLoader(train_ds, batch_size=32, shuffle=True,
                      num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_dl = DataLoader(val_ds, batch_size=32, shuffle=False,
                    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

# ImageNet-pretrained ResNet-34 with its final layer replaced by a fresh 10-way
# classifier for the ten CIFAR-10 classes.
# NOTE(review): `pretrained=` is deprecated in newer torchvision in favour of
# `weights=models.ResNet34_Weights.DEFAULT`; kept here for older installs.
model = models.resnet34(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)
print(model)

Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth


  0%|          | 0.00/83.3M [00:00<?, ?B/s]

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
def accuracy(outputs, labels):
    """Fraction of rows whose highest-scoring column matches the label.

    Args:
        outputs: score tensor with one row per sample; classes along dim 1.
        labels: class indices, one per row of ``outputs``.

    Returns:
        0-d float tensor holding the accuracy in [0, 1].
    """
    predicted = outputs.max(dim=1).indices
    n_correct = (predicted == labels).sum().item()
    return torch.tensor(n_correct / len(predicted))



def training_step(model, batch):
    """Forward one ``(images, labels)`` batch and return its cross-entropy loss.

    Both tensors are moved to the module-level ``device`` first. The returned loss
    still carries its autograd graph so the caller can ``backward()`` it.
    """
    images, labels = batch
    logits = model(images.to(device))
    return F.cross_entropy(logits, labels.to(device))

def validation_step(model, batch):
    """Score one ``(images, labels)`` batch for validation.

    Returns a dict with the detached cross-entropy loss under ``'val_loss'`` and the
    result of ``accuracy`` under ``'val_acc'``. Inputs are moved to the module-level
    ``device`` first.
    """
    images, labels = batch
    images = images.to(device)
    labels = labels.to(device)
    logits = model(images)
    return {
        'val_loss': F.cross_entropy(logits, labels).detach(),
        'val_acc': accuracy(logits, labels),
    }

def validation_epoch_end(outputs):
    """Average per-batch validation metrics into plain Python floats.

    Args:
        outputs: list of dicts, each with tensor entries ``'val_loss'`` and ``'val_acc'``.

    Returns:
        ``{'val_loss': float, 'val_acc': float}``: unweighted means over the batches.
    """
    mean_loss = torch.stack([record['val_loss'] for record in outputs]).mean()
    mean_acc = torch.stack([record['val_acc'] for record in outputs]).mean()
    return {'val_loss': mean_loss.item(), 'val_acc': mean_acc.item()}

def epoch_end(epoch, result):
    """Print a one-line summary of an epoch's metrics.

    ``result`` must provide 'train_loss', 'val_loss' and 'val_acc'. If it also has
    'lrs', the last recorded learning rate is included in the line.
    """
    if 'lrs' in result:
        lr_note = "last_lr: {:.5f},".format(result['lrs'][-1])
    else:
        lr_note = ''
    template = "Epoch [{}],{} train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}"
    print(template.format(epoch, lr_note, result['train_loss'],
                          result['val_loss'], result['val_acc']))
        

@torch.no_grad()
def evaluate(model, val_loader):
    """Compute mean validation loss/accuracy of ``model`` over ``val_loader``.

    Runs with autograd disabled and switches the model to eval mode. The model is
    left in eval mode; callers switch back with ``model.train()``. Returns the dict
    built by ``validation_epoch_end``.
    """
    model.eval()
    per_batch = []
    for batch in val_loader:
        per_batch.append(validation_step(model, batch))
    return validation_epoch_end(per_batch)



def get_lr(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group, or None if it has none."""
    first_group = next(iter(optimizer.param_groups), None)
    return None if first_group is None else first_group['lr']

def fit_one_cycle(epochs, max_lr, model, train_loader, val_loader,
                  weight_decay=0, grad_clip=None, opt_func=torch.optim.SGD):
    """Train ``model`` under a one-cycle learning-rate schedule, validating each epoch.

    Args:
        epochs: number of passes over ``train_loader``.
        max_lr: peak learning rate for ``OneCycleLR``. It is also passed to the
            optimizer as its initial learning rate.
        model: network to train (updated in place).
        train_loader: batches for ``training_step``. It must support ``len()``
            because the scheduler needs the number of steps per epoch.
        val_loader: batches passed to ``evaluate`` after every epoch.
        weight_decay: forwarded to ``opt_func``.
        grad_clip: if truthy, gradients are clipped element-wise by value with
            ``nn.utils.clip_grad_value_`` before each optimizer step.
        opt_func: optimizer class, called as
            ``opt_func(params, max_lr, weight_decay=weight_decay)``.

    Returns:
        List with one dict per epoch: the ``evaluate`` result plus ``'train_loss'``
        (mean batch loss) and ``'lrs'`` (learning rate used at each step).
    """
    torch.cuda.empty_cache()
    history = []

    # Optimizer with weight decay, driven by a per-batch one-cycle LR schedule.
    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)
    sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epochs,
                                                steps_per_epoch=len(train_loader))

    for epoch in range(epochs):
        # Training phase (evaluate() leaves the model in eval mode, so reset it).
        model.train()
        train_losses = []
        lrs = []
        for batch in tqdm(train_loader):
            loss = training_step(model, batch)
            # Store a detached copy. Keeping the live loss tensor would keep its
            # autograd graph (grad_fn chain) reachable for the whole epoch.
            train_losses.append(loss.detach())
            loss.backward()

            if grad_clip:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)

            optimizer.step()
            optimizer.zero_grad()

            # Record the LR this step actually used, then advance the schedule.
            lrs.append(get_lr(optimizer))
            sched.step()

        # Validation phase.
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        epoch_end(epoch, result)
        history.append(result)
    return history
 

           

In [None]:
# One-cycle fine-tuning hyper-parameters (used by the training cell below).
epochs = 15                  # passes over train_dl
max_lr = 0.001               # peak learning rate of the one-cycle schedule
grad_clip = 0.1              # element-wise gradient clip value
weight_decay = 1e-4          # forwarded to the optimizer
opt_func = torch.optim.Adam

# Baseline metrics before any fine-tuning (fresh 10-way head), kept as history[0].
history = []
history.append(evaluate(model, val_dl))
history


[{'val_acc': 0.1166401281952858, 'val_loss': 2.6075894832611084}]

In [None]:
# Fine-tune with the one-cycle schedule. Per-epoch metrics are appended after the
# baseline entry already stored in `history`.
history.extend(
    fit_one_cycle(
        epochs, max_lr, model, train_dl, val_dl,
        grad_clip=grad_clip,
        weight_decay=weight_decay,
        opt_func=opt_func,
    )
)

100%|██████████| 1407/1407 [10:18<00:00,  2.28it/s]


Epoch [0],last_lr: 0.00015, train_loss: 0.2714, val_loss: 0.2020, val_acc: 0.9293


100%|██████████| 1407/1407 [10:18<00:00,  2.27it/s]


Epoch [1],last_lr: 0.00044, train_loss: 0.3093, val_loss: 0.3467, val_acc: 0.8812


100%|██████████| 1407/1407 [10:18<00:00,  2.27it/s]


Epoch [2],last_lr: 0.00076, train_loss: 0.4327, val_loss: 0.4544, val_acc: 0.8485


100%|██████████| 1407/1407 [10:18<00:00,  2.28it/s]


Epoch [3],last_lr: 0.00097, train_loss: 0.4715, val_loss: 0.4090, val_acc: 0.8613


100%|██████████| 1407/1407 [10:17<00:00,  2.28it/s]


Epoch [4],last_lr: 0.00099, train_loss: 0.4646, val_loss: 0.3766, val_acc: 0.8722


100%|██████████| 1407/1407 [10:16<00:00,  2.28it/s]


Epoch [5],last_lr: 0.00095, train_loss: 0.4243, val_loss: 0.3592, val_acc: 0.8766


100%|██████████| 1407/1407 [10:15<00:00,  2.29it/s]


Epoch [6],last_lr: 0.00087, train_loss: 0.3924, val_loss: 0.3377, val_acc: 0.8891


100%|██████████| 1407/1407 [10:14<00:00,  2.29it/s]


Epoch [7],last_lr: 0.00075, train_loss: 0.3443, val_loss: 0.3046, val_acc: 0.9001


100%|██████████| 1407/1407 [10:13<00:00,  2.29it/s]


Epoch [8],last_lr: 0.00061, train_loss: 0.3019, val_loss: 0.2922, val_acc: 0.9005


100%|██████████| 1407/1407 [10:12<00:00,  2.30it/s]


Epoch [9],last_lr: 0.00046, train_loss: 0.2499, val_loss: 0.2301, val_acc: 0.9204


100%|██████████| 1407/1407 [10:11<00:00,  2.30it/s]


Epoch [10],last_lr: 0.00032, train_loss: 0.2042, val_loss: 0.1916, val_acc: 0.9361


100%|██████████| 1407/1407 [10:12<00:00,  2.30it/s]


Epoch [11],last_lr: 0.00019, train_loss: 0.1548, val_loss: 0.1753, val_acc: 0.9407


100%|██████████| 1407/1407 [10:11<00:00,  2.30it/s]


Epoch [12],last_lr: 0.00009, train_loss: 0.1118, val_loss: 0.1617, val_acc: 0.9469


100%|██████████| 1407/1407 [10:13<00:00,  2.29it/s]


Epoch [13],last_lr: 0.00002, train_loss: 0.0841, val_loss: 0.1507, val_acc: 0.9504


100%|██████████| 1407/1407 [10:12<00:00,  2.30it/s]


Epoch [14],last_lr: 0.00000, train_loss: 0.0715, val_loss: 0.1523, val_acc: 0.9502
