In [10]:
import os
import numpy as np
import matplotlib.pyplot as plt 
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset,random_split
from Data_Loader import CustomDataset
from torch.optim.lr_scheduler import ExponentialLR, CosineAnnealingWarmRestarts
from torchvision import transforms, models 
from model import Inception
from Swish import Swish 
from log_chk import Checkpoint,logger   

In [4]:
# Data-split configuration, kept in one place so it is easy to tune.
VAL_FRACTION = 0.2   # share of the dataset held out for validation
BATCH_SIZE = 64      # batch size used by both loaders
SPLIT_SEED = 0       # fixes which samples end up in the validation set

custom_dataset = CustomDataset()
validation_size = int(len(custom_dataset) * VAL_FRACTION)
# Use a dedicated, seeded generator so the train/validation membership is the
# same after every kernel restart, without touching the global RNG state.
train_dataset, val_dataset = random_split(
    custom_dataset,
    [len(custom_dataset) - validation_size, validation_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Sample order has no effect on learning during validation; keeping it fixed
# makes the reported validation metrics repeatable.
valid_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [5]:
def fit(clf,
        train_loader,
        optimizer,
        criterian,
        scheduler):
    """Run one training pass of `clf` over `train_loader`.

    Args:
        clf: model; called as `clf(batch)` and expected to return three
            values, of which only the first is used for loss and accuracy.
        train_loader: iterable of batches; element 0 of each batch is the
            model input and element 1 is the label tensor.
        optimizer: optimizer whose gradients are cleared and stepped once
            per batch.
        criterian: loss callable applied as `criterian(output, label)`.
        scheduler: learning-rate scheduler, stepped once after the pass.

    Returns:
        (loss, accuracy): the mean of the per-batch loss values and the
        percentage of samples whose arg-max prediction equals the label.
    """
    clf.train()

    loss_sum = 0
    n_correct = 0
    n_samples = 0
    n_batches = 0

    for batch in train_loader:
        n_batches += 1
        inputs, targets = batch[0], batch[1]
        n_samples += targets.size(0)

        optimizer.zero_grad()
        # The two auxiliary outputs are returned by the model but not used here.
        logits, _aux1, _aux2 = clf(inputs)
        batch_loss = criterian(logits, targets)
        loss_sum += batch_loss.item()

        # Index of the largest logit per row is the predicted class.
        predicted = logits.data.max(1)[1]
        n_correct += (predicted == targets).sum().item()

        batch_loss.backward()
        optimizer.step()

    # The schedule advances once per call, i.e. once per pass over the loader.
    scheduler.step()

    return loss_sum / n_batches, 100. * n_correct / n_samples


In [6]:
def validation (clf,validation_loader,criterian,epoch,checkpoint=None):
    """Evaluate `clf` on `validation_loader` and hand the result to a checkpointer.

    Args:
        clf: model; called as `clf(batch)` and expected to return three
            values, of which only the first is used for loss and accuracy.
        validation_loader: iterable of batches; element 0 of each batch is
            the model input and element 1 is the label tensor.
        criterian: loss callable applied as `criterian(output, label)`.
        epoch: epoch index, forwarded unchanged to the checkpointer.
        checkpoint: optional object with a `save(acc, name, epoch, model)`
            method. When omitted, the module-level `chk` is used, which
            must therefore exist before this function is called.

    Returns:
        (loss, accuracy): the mean of the per-batch loss values and the
        percentage of samples whose arg-max prediction equals the label.
    """
    clf.eval()
    valid_loss_running = 0
    valid_acc_running = 0
    total = 0
    counter = 0

    # Evaluation never back-propagates, so disable autograd: otherwise every
    # forward pass records a graph that is kept alive for no purpose.
    with torch.no_grad():
        for batch in validation_loader:
            counter += 1
            data, label = batch[0], batch[1]
            total += label.size(0)
            # The two auxiliary outputs are returned by the model but not used here.
            out, _aux1, _aux2 = clf(data)
            loss = criterian(out, label)
            valid_loss_running += loss.item()
            _, pred = torch.max(out, 1)
            valid_acc_running += (pred == label).sum().item()

    valid_loss = valid_loss_running / counter
    valid_acc = 100. * valid_acc_running / total

    # Fall back to the notebook-level `chk` so existing four-argument calls keep working.
    saver = chk if checkpoint is None else checkpoint
    saver.save(valid_acc,'chk',epoch,clf)

    return valid_loss,valid_acc

        


In [7]:
def train (hyparam,train_loader,val_loader):
    """Build an Inception model and train it for `hyparam['epoch']` epochs.

    Args:
        hyparam: mapping that must provide 'lr' (initial SGD learning rate)
            and 'epoch' (number of epochs to run). Other keys are not read.
        train_loader: batches passed to `fit` each epoch.
        val_loader: batches passed to `validation` each epoch.

    Returns:
        (clf, train_loss, train_acc, val_loss, val_acc): the trained model
        followed by four lists holding one value per completed epoch.
    """
    clf = Inception()
    optimizer = torch.optim.SGD(clf.parameters(), lr=hyparam['lr'],momentum=0.9, weight_decay=0.001)
    criterian = nn.CrossEntropyLoss()
    # `fit` steps the scheduler once per call, so T_0 is measured in epochs.
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=1)

    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []

    n_epochs = hyparam['epoch']
    for epoch in range(n_epochs):
        print(f"Epoch {epoch+1} of {n_epochs}")
        training_loss,training_acc = fit(clf,train_loader,optimizer,criterian,scheduler)
        validation_loss,validation_acc = validation(clf,val_loader,criterian,epoch)

        train_loss.append(training_loss)
        train_acc.append(training_acc)

        val_loss.append(validation_loss)
        val_acc.append(validation_acc)

        # Build the metrics text once from adjacent literals. A backslash
        # continuation inside a string would embed the next line's
        # indentation in the emitted message.
        summary = (
            f"Train Loss: {training_loss:.4f}, Train Acc: {training_acc:.2f}, "
            f"Val Loss: {validation_loss:.4f}, Val Acc: {validation_acc:.2f}"
        )
        logger.info(f" Epoch: {epoch + 1}, {summary}")
        print(summary)

    return clf,train_loss,train_acc,val_loss,val_acc
        



In [17]:
# Checkpoint helper imported from log_chk. validation() looks it up through the
# global name `chk`, so this cell has to run before train() is called.
chk = Checkpoint()

In [9]:
# Run configuration handed to train().
# NOTE(review): train() only reads 'lr' and 'epoch'; 'batch_size' is not
# consumed there, and the DataLoaders are created with their own batch size.
hparams = {
    'batch_size': 256,
    'lr': 0.6e-1,  # original inline note: 6e-4 (presumably an alternative value)
    'epoch': 50,
}

# Train the model and keep the per-epoch metric histories for later inspection.
clf, train_loss, train_acc, val_loss, val_acc = train(
    hparams,
    train_loader,
    valid_loader,
)

Epoch 1 of 50
Train Loss: 3.2520, Train Acc: 51.55,         Val Loss: 4.5937, Val Acc: 53.35
Epoch 2 of 50
Train Loss: 0.4217, Train Acc: 89.02,         Val Loss: 0.0804, Val Acc: 97.69
Epoch 3 of 50
Train Loss: 0.1287, Train Acc: 96.23,         Val Loss: 0.0322, Val Acc: 99.50
Epoch 4 of 50


KeyboardInterrupt: 