<a href="https://colab.research.google.com/github/O00O297/CCA-AI/blob/master/train_CNN_VGG16.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

import torch.utils.data as data

from PIL import Image
import os
import os.path
from os.path import isfile, join
from os import walk
from os import listdir

In [None]:
# Detect once whether a CUDA device is usable and report the choice.
use_gpu = torch.cuda.is_available()
print("Using CUDA" if use_gpu else "CPU")

In [None]:
def default_loader(path):
  """Load the image stored at *path* and return it converted to RGB.

  Args:
    path: Filesystem path of the image file.

  Returns:
    The result of ``convert('RGB')`` on the opened image.
  """
  # Open through an explicit file handle so the descriptor is closed as soon
  # as the pixels have been read; a bare Image.open(path) leaves closing the
  # file to garbage collection. convert() forces the image data to be loaded,
  # so the returned image no longer needs the file.
  with open(path, 'rb') as f:
    img = Image.open(f)
    return img.convert('RGB')

def default_flist_reader(flist, path):
  """Expand a directory list file into ``(relative_image_path, label)`` pairs.

  Every non-blank line of *flist* has the form ``<sub-directory> <label>``.
  The label is the last whitespace-separated token, so the sub-directory
  itself may contain spaces. For each line, the files located directly inside
  ``path/<sub-directory>`` are listed (nested folders are not descended into)
  and paired with the label.

  Args:
    flist: Path of the text file listing sub-directories and labels.
    path: Root directory that the listed sub-directories are relative to.

  Returns:
    A list of ``(sub-directory + "/" + filename, label)`` tuples; the label is
    kept as the string read from the file. Files of one directory are returned
    in sorted order so the result is reproducible across runs.

  Raises:
    FileNotFoundError: If a listed sub-directory cannot be read.
    ValueError: If a non-blank line does not contain both a path and a label.
  """
  imlist = []
  with open(flist, 'r') as rf:
    for line in rf:
      line = line.strip()
      if not line:
        # Tolerate blank lines (e.g. a trailing newline at the end of the file).
        continue
      impath, imlabel = line.rsplit(None, 1)
      directory = path + "/" + impath
      print(directory)
      # Only the first entry yielded by walk() (the directory itself) is
      # needed; walk() yields nothing when the directory cannot be read.
      top_level = next(walk(directory), None)
      if top_level is None:
        raise FileNotFoundError(
            "Image directory listed in {} not found: {}".format(flist, directory))
      for img in sorted(top_level[2]):
        imlist.append((impath + "/" + img, imlabel))
  return imlist

class ImageFilelist(data.Dataset):
  """Dataset over the ``(relative_path, label)`` pairs produced by a list reader.

  Args:
    root: Directory that the relative image paths are joined onto.
    flist: Path of the list file handed to ``flist_reader``.
    transform: Optional callable applied to each loaded image.
    target_transform: Optional callable applied to each label.
    flist_reader: Callable invoked as ``flist_reader(flist, root)`` that
      returns a list of ``(relative_path, label)`` tuples.
    loader: Callable that loads an image from a full path.
  """

  def __init__(self, root, flist, transform=None, target_transform=None,
               flist_reader=default_flist_reader, loader=default_loader):
    self.root = root
    # The reader takes the dataset root as its second argument; calling it
    # with the list file alone raised a TypeError with the default reader.
    self.imlist = flist_reader(flist, root)
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader

  def __getitem__(self, index):
    """Return ``(image, target)`` for entry *index*, after optional transforms."""
    impath, target = self.imlist[index]
    img = self.loader(os.path.join(self.root, impath))
    if self.transform is not None:
      img = self.transform(img)
    if self.target_transform is not None:
      target = self.target_transform(target)

    return img, target

  def __len__(self):
    """Return the number of ``(path, label)`` entries."""
    return len(self.imlist)

In [None]:
# --- Training split: random crop + horizontal flip, shuffled batches of 8 ---
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
train_dataset = ImageFilelist(
    root="/content/drive/Shared drives/CCA-AI Slide",
    flist="/content/drive/Shared drives/CCA-AI Slide/filepath_train.txt",
    transform=train_transform,
)
train_loaders = torch.utils.data.DataLoader(
    train_dataset, batch_size=8, shuffle=True, num_workers=4)

# --- Validation split: deterministic centre crop, batches are shuffled ---
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
val_dataset = ImageFilelist(
    root="/content/drive/Shared drives/CCA-AI Slide",
    flist="/content/drive/Shared drives/CCA-AI Slide/filepath_val.txt",
    transform=val_transform,
)
val_loaders = torch.utils.data.DataLoader(
    val_dataset, batch_size=8, shuffle=True, num_workers=4)

# --- Test split: same preprocessing as validation, kept in list order ---
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
test_dataset = ImageFilelist(
    root="/content/drive/Shared drives/CCA-AI Slide",
    flist="/content/drive/Shared drives/CCA-AI Slide/filepath_test.txt",
    transform=test_transform,
)
test_loaders = torch.utils.data.DataLoader(
    test_dataset, batch_size=8, shuffle=False, num_workers=4)

In [None]:
def imshow(inp, title=None):
    """Draw an image tensor with matplotlib, optionally with a title.

    Args:
        inp: Tensor whose axes are reordered with ``(1, 2, 0)`` before
            plotting (presumably channels-first; verify against callers).
        title: Optional title for the plot; omitted when ``None``.
    """
    # Move axis 0 to the end so matplotlib receives it as the last axis.
    image = np.transpose(inp.numpy(), (1, 2, 0))
    plt.axis('off')
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    # Short pause so the figure gets a chance to render.
    plt.pause(0.001)

def show_databatch(inputs, classes):
    """Tile *inputs* into one grid image and show it titled with *classes*.

    Args:
        inputs: Image tensor(s) passed to ``torchvision.utils.make_grid``.
        classes: Iterable of labels; its elements become the plot title.
    """
    grid = torchvision.utils.make_grid(inputs)
    labels = list(classes)
    imshow(grid, title=labels)

# Fetch one batch from the validation loader and preview it.
# NOTE(review): only the first image and its label are passed on, although
# show_databatch builds a grid from whatever it receives; confirm whether the
# whole batch (inputs, classes) was meant to be shown.
inputs, classes = next(iter(val_loaders))
show_databatch(inputs[0], classes[0])

In [None]:
# Number of output classes of the classifier head.
NUM_CLASSES = 3

# Build VGG16 with batch norm from randomly initialised weights. The
# `pretrained` keyword is left out on purpose: it is deprecated in recent
# torchvision releases, and the default already means "no pretrained weights".
vgg16 = models.vgg16_bn(num_classes=NUM_CLASSES)
# vgg16.load_state_dict(torch.load("/content/drive/Shared drives/CCA-AI Slide/Dataset/Model/VGG/vgg16_bn.pth"))

# Prints NUM_CLASSES because the head was built with num_classes above.
print(vgg16.classifier[6].out_features)

# Optionally freeze the convolutional feature extractor.
# The attribute autograd reads is `requires_grad`; the previous spelling
# `require_grad` only created an unused attribute, so nothing was ever frozen.
# Freezing stays off by default because no pretrained weights are loaded here
# and frozen random features could not learn anything. Switch it on only
# after loading pretrained weights.
FREEZE_FEATURES = False
if FREEZE_FEATURES:
    for param in vgg16.features.parameters():
        param.requires_grad = False

# Newly created modules have requires_grad=True by default
num_features = vgg16.classifier[6].in_features
features = list(vgg16.classifier.children())[:-1]  # Drop the last layer
features.extend([nn.Linear(num_features, NUM_CLASSES)])  # New head with NUM_CLASSES outputs
vgg16.classifier = nn.Sequential(*features)  # Replace the model classifier
print(vgg16)

In [None]:
# Move all model parameters and buffers to the GPU when one is available.
if use_gpu:
    vgg16.cuda()

# Classification loss.
criterion = nn.CrossEntropyLoss()

# SGD with momentum over every model parameter; the StepLR schedule multiplies
# the learning rate by 0.1 every 7 scheduler steps.
optimizer_ft = optim.SGD(params=vgg16.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer_ft, step_size=7, gamma=0.1)

In [None]:
def _prepare_batch(batch, device):
    """Turn one ``(inputs, labels)`` batch into tensors placed on *device*."""
    inputs, labels = batch
    inputs = torch.as_tensor(inputs, dtype=torch.float32)
    if torch.is_tensor(labels):
        labels = labels.long()
    else:
        # Labels that are not tensors (e.g. a sequence of digit strings) are
        # parsed element by element.
        labels = torch.tensor([int(label) for label in labels], dtype=torch.long)
    return inputs.to(device), labels.to(device)


def train_model(vgg, criterion, optimizer, scheduler, num_epochs=10,
                train_loader=None, val_loader=None):
    """Train *vgg* and return it loaded with the best-validation weights.

    Each epoch trains on the first half (rounded up) of the training batches,
    evaluates on every validation batch, prints the epoch metrics and advances
    the learning-rate scheduler by one step.

    Args:
        vgg: Model to train. It is used on CUDA when CUDA is available, so it
            must already live on that device.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates the parameters of *vgg*.
        scheduler: Learning-rate scheduler stepped once per epoch; ``None``
            disables scheduling.
        num_epochs: Number of epochs to run.
        train_loader: Sized iterable of ``(inputs, labels)`` training batches.
            Defaults to the module-level ``train_loaders``.
        val_loader: Sized iterable of ``(inputs, labels)`` validation batches.
            Defaults to the module-level ``val_loaders``.

    Returns:
        *vgg*, with the state dict of the epoch that reached the highest
        validation accuracy loaded into it.
    """
    if train_loader is None:
        train_loader = train_loaders
    if val_loader is None:
        val_loader = val_loaders

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    since = time.time()
    best_model_wts = copy.deepcopy(vgg.state_dict())
    best_acc = 0.0

    train_batches = len(train_loader)
    val_batches = len(val_loader)
    # Use half of the training batches per epoch, rounded up.
    train_limit = (train_batches + 1) // 2

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch + 1, num_epochs))
        print('-' * 10)

        # ---- training phase ----
        loss_train = 0.0
        correct_train = 0
        seen_train = 0
        train_done = 0

        vgg.train()
        for i, batch in enumerate(train_loader):
            if i >= train_limit:
                break
            if i % 100 == 0:
                print("\rTraining batch {}/{}".format(i, train_limit), end='', flush=True)

            inputs, labels = _prepare_batch(batch, device)

            optimizer.zero_grad()
            outputs = vgg(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            _, preds = torch.max(outputs.detach(), 1)
            # .item() stores plain Python numbers instead of tensors.
            loss_train += loss.item()
            correct_train += (preds == labels).sum().item()
            seen_train += labels.size(0)
            train_done += 1

        print()
        # Loss is averaged over processed batches, accuracy over processed
        # samples; max(..., 1) guards against an empty loader.
        avg_loss = loss_train / max(train_done, 1)
        avg_acc = correct_train / max(seen_train, 1)

        # ---- validation phase ----
        loss_val = 0.0
        correct_val = 0
        seen_val = 0
        val_done = 0

        vgg.eval()
        # no_grad() is what disables autograd during evaluation.
        with torch.no_grad():
            for i, batch in enumerate(val_loader):
                if i % 100 == 0:
                    print("\rValidation batch {}/{}".format(i, val_batches), end='', flush=True)

                inputs, labels = _prepare_batch(batch, device)
                outputs = vgg(inputs)
                loss = criterion(outputs, labels)

                _, preds = torch.max(outputs, 1)
                loss_val += loss.item()
                correct_val += (preds == labels).sum().item()
                seen_val += labels.size(0)
                val_done += 1

        avg_loss_val = loss_val / max(val_done, 1)
        avg_acc_val = correct_val / max(seen_val, 1)

        print()
        print("Epoch {} result: ".format(epoch + 1))
        print("Avg loss (train): {:.4f}".format(avg_loss))
        print("Avg acc (train): {:.4f}".format(avg_acc))
        print("Avg loss (val): {:.4f}".format(avg_loss_val))
        print("Avg acc (val): {:.4f}".format(avg_acc_val))
        print('-' * 10)
        print()

        # Keep a copy of the weights of the best epoch seen so far.
        if avg_acc_val > best_acc:
            best_acc = avg_acc_val
            best_model_wts = copy.deepcopy(vgg.state_dict())

        # Advance the learning-rate schedule once per epoch, after the
        # optimizer updates of that epoch.
        if scheduler is not None:
            scheduler.step()

    elapsed_time = time.time() - since
    print()
    print("Training completed in {:.0f}m {:.0f}s".format(elapsed_time // 60, elapsed_time % 60))
    print("Best acc: {:.4f}".format(best_acc))

    vgg.load_state_dict(best_model_wts)
    return vgg

In [None]:
# Train for 5 epochs, then write the resulting weights to the shared drive.
vgg16 = train_model(
    vgg16,
    criterion,
    optimizer_ft,
    exp_lr_scheduler,
    num_epochs=5,
)
model_weights = vgg16.state_dict()
# torch.save has no return value, so `status` is None and the print shows "None".
status = torch.save(
    model_weights,
    '/content/drive/Shared drives/CCA-AI Slide/Slide_preprocess/model3class.pt',
)
print(status)