<a href="https://colab.research.google.com/github/TongleiChen/sketch_to_image/blob/main/COSC576_project_early_Xmas.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# @title import library
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import time
import torch.nn.functional as F
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import models

pre-trained model reference: https://debuggercafe.com/transfer-learning-with-pytorch/

In [9]:
# @title IMAGE CLASSIFICATION
import numpy as np
import os
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler


import random
import pandas as pd

from torchvision import transforms, utils

class ImageDataset(Dataset):

    def __init__(self, img_dir, animal_category, image_size = 255, class_size = 100,transform = False):
        """
        Args:
            sketch_dir (string): Directory to all the sketch images.
            realworld_dir (string): Directory to all the real world images.
            animal_category: list to fruit catogory
            class_size: Num of images in each category
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.img_dir = img_dir
        self.transform = transform
        self.animal_category = animal_category
        self.class_size = class_size
        self.data_dict = dict(np.load(img_dir,allow_pickle=True))
        self.image_size = image_size
       
        self.transform_img = transforms.Compose([transforms.ToPILImage(),
                                                transforms.Resize((image_size,image_size)),
                                                transforms.ToTensor(),])
                                              # transforms.Normalize( mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225] )
        self.transform_label = transforms.Compose([transforms.ToTensor()])
        
    def __len__(self):
        return self.class_size * len(self.animal_category)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        class_index = int(idx // self.class_size)
        category =  self.animal_category[class_index]
        category_idx = int(idx % self.class_size)
        label = np.zeros((len(self.animal_category), 1))
        label[class_index] = 1
        
#         label = class_index
        image_ary =  self.data_dict[category][category_idx]
        sample = {'image': image_ary, 'label': label}
        if self.transform:
            sample['image'] = self.transform_img(sample['image'])
            sample['label'] = self.transform_label(sample['label'])
        return sample
        

In [10]:
# @title Load image data
QURIES = ["bear","camel","cat","dog","elephant","frog","lion","panda","rabbit","squirrel"]  

train_realworld_dir = "/content/drive/MyDrive/kaggle/imagenet/all_images.npz"
train_realworld = ImageDataset(train_realworld_dir, QURIES, image_size = 64, class_size = 500,transform = True)
train_loader = DataLoader(train_realworld, batch_size=128, shuffle=True, pin_memory=True)

test_realworld_dir = "/content/drive/MyDrive/kaggle/imagenet/test_images.npz"
test_realworld = ImageDataset(test_realworld_dir, QURIES, image_size = 64, class_size = 10,transform = True)
test_loader_realworld = DataLoader(test_realworld, batch_size=128, shuffle=True, pin_memory=True)

In [6]:
image_model = models.vgg16(pretrained=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_model.to(device)
print(image_model)



VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [7]:
image_model.classifier[6].out_features = 10
# freeze convolution weights
for param in image_model.features.parameters():
    param.requires_grad = False

In [12]:
# Loss and optimizer
num_classes = 10
num_epochs = 30
batch_size = 64
learning_rate = 0.01 #0.03
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(image_model.classifier.parameters(), lr=learning_rate, weight_decay = 0.007, momentum = 0.9)  
LOAD = False
start_epoch =0
train_realworld_acc = []
train_realworld_loss = []
val_realworld_acc = []
val_realworld_loss = []
import gc
total_step = len(train_loader)

for epoch in range(start_epoch,start_epoch+num_epochs):
    for data in train_loader:  
        # Move tensors to the configured device
        images = data['image'].to(device)
        labels = data['label'].to(device)
        correct_t = 0
        total_t = 0
        # Forward pass
        image_model.train()
        outputs = image_model(images)
        loss = criterion(outputs, torch.argmax(labels.squeeze(),dim = 1))
        _, predicted = outputs.squeeze().topk(1, dim=1, largest=True, sorted=True)
        
        total_t += labels.size(0)
        correct_t +=  (predicted.squeeze() == torch.argmax(labels.squeeze(),dim = 1)).sum().item()
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        del images, labels, outputs
        torch.cuda.empty_cache()
        gc.collect()


        train_realworld_loss.append(loss.item())
    train_realworld_acc.append(correct_t/total_t)
    print ('Epoch [{}/{}], Training Loss: {:.4f}' 
                   .format(epoch+1,start_epoch+num_epochs, loss.item()))
            
    # Validation
    
    with torch.no_grad():
        image_model.eval()
        correct_v = 0
        total_v = 0
        for data in test_loader_realworld:
            images = data['image'].to(device)
            labels = data['label'].to(device)
        

            outputs = image_model(images)
            loss_v = criterion(outputs, torch.argmax(labels.squeeze(),dim = 1))
   
            _, predicted = outputs.squeeze().topk(1, dim=1, largest=True, sorted=True)
        
            total_v += labels.size(0)
            correct_v +=  (predicted.squeeze() == torch.argmax(labels.squeeze(),dim = 1)).sum().item()
            
            del images, labels, outputs
        val_realworld_loss.append(loss_v.item())
        val_realworld_acc.append(correct_v/total_v)
        print(f"Validation accuracy: {(correct_v/total_v):.3f}")
       
print(f"train_acc:{(np.mean([v for v in train_realworld_acc])):.3f},val_acc:{(np.mean([v for v in val_realworld_acc])):.3f}")
    

Epoch [1/30], Training Loss: 0.7527
Validation accuracy: 0.580
Epoch [2/30], Training Loss: 0.8536
Validation accuracy: 0.630
Epoch [3/30], Training Loss: 0.9809
Validation accuracy: 0.680
Epoch [4/30], Training Loss: 1.2078
Validation accuracy: 0.630
Epoch [5/30], Training Loss: 0.4609
Validation accuracy: 0.620
Epoch [6/30], Training Loss: 0.7744
Validation accuracy: 0.680
Epoch [7/30], Training Loss: 0.6792
Validation accuracy: 0.630
Epoch [8/30], Training Loss: 1.7265
Validation accuracy: 0.700
Epoch [9/30], Training Loss: 0.1999
Validation accuracy: 0.700
Epoch [10/30], Training Loss: 0.1726
Validation accuracy: 0.640
Epoch [11/30], Training Loss: 0.2447
Validation accuracy: 0.690
Epoch [12/30], Training Loss: 0.6001
Validation accuracy: 0.680
Epoch [13/30], Training Loss: 0.3548
Validation accuracy: 0.660
Epoch [14/30], Training Loss: 0.2136
Validation accuracy: 0.690
Epoch [15/30], Training Loss: 0.5625
Validation accuracy: 0.690
Epoch [16/30], Training Loss: 0.2110
Validation a