In [None]:
# Mount Google Drive at /content/gdrive/ (Colab only); the data paths used below live under it.
from google.colab import drive
drive.mount('/content/gdrive/')


In [None]:
!pip install tqdm
import glob
import os.path as osp
import random
import numpy as np
import json
from PIL import Image
import matplotlib.pyplot as plt

%matplotlib inline
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import models, transforms
from tqdm import tqdm

In [None]:
# Seed PyTorch, NumPy and the stdlib RNG (in that order) with the same value.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(42)

In [None]:
# Ask cuDNN for deterministic algorithms and switch off its benchmark-based autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
class ImageTransform():
  """Image preprocessing pipelines, one per phase ('train' and 'val').

  'train': RandomResizedCrop(resize, scale 0.5-1.0) -> RandomHorizontalFlip(0.7)
           -> ToTensor -> Normalize(mean, std)
  'val':   Resize(resize) -> CenterCrop(resize) -> ToTensor -> Normalize(mean, std)

  Args:
    resize: target size passed to the crop/resize transforms.
    mean: per-channel mean passed to Normalize.
    std: per-channel standard deviation passed to Normalize.
      (This parameter used to be named ``seed`` and was ignored; the pipeline
      read a global ``std`` instead. It is now actually used.)
  """

  def __init__(self, resize, mean, std):
    self.data_transform = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(resize, scale=(0.5, 1.0)),
            transforms.RandomHorizontalFlip(0.7),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]),
        'val': transforms.Compose([
            transforms.Resize(resize),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]),
    }

  def __call__(self, img, phase = 'train'):
    """Apply the pipeline registered for `phase` to `img`.

    Raises:
      KeyError: if `phase` is neither 'train' nor 'val'.
    """
    return self.data_transform[phase](img)

In [None]:
# Sample image opened in the preview cell below.
img_file_path = "/content/gdrive/MyDrive/ML/Py_torch/data/Hinh-anh-cun-con-de-thuong-cute-lam-hinh-nen-dep-8.jpg"

In [None]:
# Show the untouched sample image first.
img = Image.open(img_file_path)
plt.imshow(img)
plt.show()

# Preprocessing constants; the dataset cells further down reuse them.
size = 224
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

transform = ImageTransform(size, mean, std)
img_transformed = transform(img, 'train')

# Tensor layout (channel, height, width) -> image layout (height, width, channel),
# then clamp the values into [0, 1] before handing the array to imshow.
img_transformed = np.clip(np.transpose(img_transformed.numpy(), (1, 2, 0)), 0, 1)
plt.imshow(img_transformed)
plt.show()

In [None]:
def make_datapath_list(phase = 'train', root_path = "/content/gdrive/MyDrive/ML/Py_torch/data/hymenoptera_data/"):
  """Return the sorted list of .jpg paths for one phase of the dataset.

  The search pattern is ``<root_path>/<phase>/**/*.jpg``. glob is called
  without ``recursive=True``, so ``**`` matches exactly one directory level
  (the per-class folder).

  Args:
    phase: sub-directory of `root_path` to scan (e.g. 'train' or 'val').
    root_path: dataset root; defaults to the original Drive location.

  Returns:
    list[str]: matching paths in sorted order, so the listing does not depend
    on filesystem enumeration order. Empty if nothing matches.
  """
  target_path = osp.join(root_path, phase, "**", "*.jpg")
  return sorted(glob.glob(target_path))


In [None]:
# Quick check: number of training image paths found.
path_list = make_datapath_list('train')
print(len(path_list))

In [None]:
# Peek at the first ten paths.
path_list[: 10]

In [None]:
# One path list per phase; these feed the datasets built below.
train_list = make_datapath_list('train')
val_list = make_datapath_list('val')

In [None]:
class MyDataSet(data.Dataset):
  """Dataset of (image, label) pairs read from a list of image paths.

  The label comes from the name of the folder that directly contains the
  image: 'ants' -> 0, 'bees' -> 1.

  Args:
    file_list: sequence of image file paths.
    transform: callable invoked as ``transform(img, phase)``; when None the
      RGB PIL image is returned untransformed.
    phase: phase name forwarded to `transform` ('train' or 'val').
  """

  # Class folder name -> integer target.
  LABELS = {'ants': 0, 'bees': 1}

  def __init__(self, file_list, transform=None, phase= 'train'):
    self.file_list = file_list
    self.transform = transform
    self.phase = phase

  def __len__(self):
    """Number of image paths in the dataset."""
    return len(self.file_list)

  @classmethod
  def _label_from_path(cls, img_path):
    """Map the parent folder name of `img_path` to its integer label.

    Raises:
      ValueError: if the parent folder is not a known class name.
    """
    folder = osp.basename(osp.dirname(img_path))
    try:
      return cls.LABELS[folder]
    except KeyError:
      raise ValueError(
          "Unknown class folder {!r} for image {!r}; expected one of {}".format(
              folder, img_path, sorted(cls.LABELS))) from None

  def __getitem__(self, idx):
    """Load the image at `idx` and return ``(image, label)``."""
    img_path = self.file_list[idx]

    # Close the file promptly, and force 3 channels so that grayscale/RGBA/
    # palette files do not break a 3-channel Normalize downstream.
    with Image.open(img_path) as raw_img:
      img = raw_img.convert('RGB')

    if self.transform is not None:
      img = self.transform(img, self.phase)

    return img, self._label_from_path(img_path)

In [None]:
# Both datasets are built with the same size/mean/std; only the file list and the phase differ.
train_dataset = MyDataSet(train_list, transform= ImageTransform(size, mean, std), phase= 'train')
val_dataset = MyDataSet(val_list, transform= ImageTransform(size, mean, std), phase= 'val')

In [None]:
# Sanity check: dataset size, then the shape and label of one sample.
index = 0
print(len(train_dataset))
img, img_label = train_dataset[index]
print(img.shape)
print(img_label)

In [None]:
# Mini-batch loaders keyed by phase; only the training loader shuffles.
batch_size = 4
dataloader_dict = {
    'train': data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
    'val': data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False),
}
train_dataloader = dataloader_dict['train']
val_dataloader = dataloader_dict['val']

In [None]:
# Pull a single mini-batch from the training loader for inspection.
batch_iterator = iter(dataloader_dict['train'])
inputs, labels = next(batch_iterator)

In [None]:
# Size of the batch tensor and the labels that came with it.
print(inputs.size())
print(labels)

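Network: VGG16 loaded with pretrained weights, with its last classifier layer replaced by a 2-output layer.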

In [None]:
# Load VGG16; use_pretrained=True requests the pretrained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of `weights=` — confirm against the installed version.
use_pretrained = True
net = models.vgg16(pretrained = use_pretrained)
# print(net)

In [None]:
# Swap the last classifier layer (index 6) for a fresh 4096 -> 2 linear layer.
net.classifier[6] = nn.Linear(in_features=4096, out_features=2)

In [None]:
# print(net)

In [None]:
# Put the network in training mode; the result of train() is assigned back to `net`.
net = net.train()

Loss: cross-entropy between the network's two outputs and the integer class labels.


In [None]:
# Loss function later passed to train_model.
criterior = nn.CrossEntropyLoss()

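Optimizer: SGD with momentum, updating only the replaced final layer (`classifier.6`); every other parameter is frozen.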

In [None]:
# Only the replaced final layer is trained; every other parameter is frozen.
update_params_name = ["classifier.6.weight", "classifier.6.bias"]

params_to_update = []
for name, param in net.named_parameters():
  trainable = name in update_params_name
  param.requires_grad = trainable  # gradients are kept only for the trainable tensors
  if trainable:
    params_to_update.append(param)

print(params_to_update)


In [None]:
# SGD with momentum over the parameters collected in params_to_update only.
optimizer = optim.SGD(params=params_to_update, lr=0.001, momentum=0.9)

In [None]:
def train_model(net, dataloader_dict, criterior, optimizer, num_epochs):
    """Run the train/validation loop and print loss and accuracy per phase.

    Epoch 0 is evaluation-only: its training phase is skipped so the first
    validation numbers show the model before any update. ``num_epochs``
    therefore yields ``num_epochs - 1`` training passes.

    Args:
        net: the model to train; batches are moved to the device of its
            first parameter before the forward pass.
        dataloader_dict: mapping with 'train' and 'val' loaders, each exposing
            a ``.dataset`` with a length.
        criterior: loss callable invoked as ``criterior(outputs, labels)``.
        optimizer: optimizer stepped once per training batch.
        num_epochs: number of epochs, including the evaluation-only epoch 0.

    Returns:
        None. Results are printed.
    """
    # Follow whatever device the model lives on so batches never mismatch it.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs))

        for phase in ["train", "val"]:
            # Epoch 0: validation only, to record the untrained baseline.
            if (epoch == 0) and (phase == "train"):
                continue

            dataloader = dataloader_dict[phase]
            num_samples = len(dataloader.dataset)
            if num_samples == 0:
                # Nothing to average over; avoid dividing by zero below.
                print("{} skipped: dataset is empty".format(phase))
                continue

            if phase == "train":
                net.train()
            else:
                net.eval()

            epoch_loss = 0.0
            epoch_corrects = 0

            for inputs, labels in tqdm(dataloader):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Gradients are tracked only while training.
                with torch.set_grad_enabled(phase == "train"):
                    outputs = net(inputs)
                    loss = criterior(outputs, labels)
                    _, preds = torch.max(outputs, 1)

                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # The loss is multiplied by the batch size so the epoch total
                # can be divided by the number of samples afterwards.
                epoch_loss += loss.item() * inputs.size(0)
                epoch_corrects += int(torch.sum(preds == labels).item())

            epoch_loss = epoch_loss / num_samples
            epoch_accuracy = epoch_corrects / num_samples

            print("{} Loss: {:.4f} Acc: {:.4f}".format(phase, epoch_loss, epoch_accuracy))

In [None]:
# train_model skips the training phase of epoch 0, so 8 epochs = 1 evaluation-only pass + 7 training passes.
num_epoch = 8
train_model(net, dataloader_dict, criterior, optimizer, num_epoch)

In [None]:
# nn.Module has no .save() method (that is the Keras API), so persist the learned
# weights with torch.save instead. The file name is kept unchanged; reload with
# net.load_state_dict(torch.load('net.h5')).
torch.save(net.state_dict(), 'net.h5')