In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
!pip install timm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting timm
  Downloading timm-0.6.12-py3-none-any.whl (549 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m549.1/549.1 KB[0m [31m18.1 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub
  Downloading huggingface_hub-0.12.0-py3-none-any.whl (190 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.3/190.3 KB[0m [31m20.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: huggingface-hub, timm
Successfully installed huggingface-hub-0.12.0 timm-0.6.12


In [4]:
import torch
import torch.nn as nn
import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import torch.optim as optim
from glob import glob
import torchvision
from PIL import Image
from sklearn.model_selection import train_test_split
import cv2
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

In [5]:
folderpath = "/content/drive/MyDrive/Colab Notebooks/2023/Performance Comparison"
datapath = f"{folderpath}/data"

In [6]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [7]:
label_list = ['Butterfly', 'Cat', 'Chicken', 'Cow', 'Dog', 'Elephant', 'Horse', 'Sheep', 'Spider', 'Squirrel']

In [8]:
class DataProcess():
    def __init__(self, dataType='train'):
        self.dataType = dataType
        self.imgPathList = []
        self.imgCountList = []
        self.totalCount = 0

        for label in label_list:
            image_path = f"{datapath}/{label}"
            image_list = glob(f"{image_path}/*.jpg")
            _len = len(image_list)
            print(label, _len)

            if dataType =='train':
                image_list = image_list[:int(_len * 0.8)]
            elif dataType == 'valid':
                image_list = image_list[int(_len * 0.8):] 

            self.imgPathList.append(image_list)
            count = len(image_list)
            self.imgCountList.append(count)
            self.totalCount += count

    def findImagePath(self, idx):
        label_idx = 0
        for (i, count) in enumerate(self.imgCountList):
            label_idx = i
            if idx < count:
                path = self.imgPathList[i][idx]
                return path, label_idx
            idx -= count
        return self.imgPathList[-1][idx], label_idx

    def getLength(self):
        return self.totalCount
     

In [9]:
train_dataprocess = DataProcess(dataType='train')
valid_dataprocess = DataProcess(dataType='valid')
# test_dataprocess = DataProcess(dataType='test')

Butterfly 422
Cat 440
Chicken 0
Cow 0
Dog 0
Elephant 359
Horse 0
Sheep 376
Spider 322
Squirrel 0
Butterfly 422
Cat 440
Chicken 0
Cow 0
Dog 0
Elephant 359
Horse 0
Sheep 376
Spider 322
Squirrel 0


In [10]:
print(train_dataprocess.getLength(), valid_dataprocess.getLength())

1533 386


In [11]:
vit_model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)
vit_model = vit_model.to(device)

In [12]:
config = resolve_data_config({}, model=vit_model)
vit_transform = create_transform(**config)

In [13]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, mode='train', transforms=None):
        self.transforms = transforms
        self.mode = mode
    
    def __getitem__(self, idx):
        if self.mode == 'train':
            path, label = train_dataprocess.findImagePath(idx)
        elif self.mode == 'valid':
            path, label = valid_dataprocess.findImagePath(idx)
        elif self.mode == 'test':
            path, label = test_dataprocess.findImagePath(idx)
        else:
            print('Invalid Mode')
            assert(0)
        image = Image.open(path)
        image = image.convert("RGB")
        # if self.transforms is not None:
        #     image = self.transforms(image)
        image = vit_transform(image)
            
        return (image, label)

    def __len__(self):
        if self.mode == 'train':
            return train_dataprocess.getLength()
        elif self.mode == 'valid':
            return valid_dataprocess.getLength()
        return test_dataprocess.getLength()

In [14]:
transforms = torchvision.transforms.Compose([
                  torchvision.transforms.Resize((224, 224)),
                  torchvision.transforms.ToTensor(),
                  torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [15]:
train_dataset = CustomDataset(mode='train', transforms=transforms)
valid_dataset = CustomDataset(mode='valid', transforms=transforms)
# test_dataset = CustomDataset(mode='test', transforms=transforms)

In [16]:
batch_size = 64

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

In [17]:
# !pip install timm

In [18]:
learning_rate = 0.001
num_epochs = 1000

vit_optimizer = torch.optim.SGD(vit_model.parameters(), lr=learning_rate)
vit_scheduler = optim.lr_scheduler.LambdaLR(optimizer=vit_optimizer,
                                        lr_lambda=lambda epoch: 0.95 ** epoch,
                                        last_epoch=-1,
                                        verbose=False)

modelpath = f"{folderpath}/model/ViT"

In [19]:
import os.path

if os.path.exists(f"{modelpath}/model_state_dict.pt"):
    vit_model.load_state_dict(torch.load(f"{modelpath}/model_state_dict.pt"))
    vit_optimizer.load_state_dict(torch.load(f"{modelpath}/optim_state_dict.pt"))
    vit_scheduler.load_state_dict(torch.load(f"{modelpath}/scheduler_state_dict.pt"))
    print('Load Complete')
else:
    print('Load Fail')

Load Fail


In [20]:
def getAverage(l):
    return sum(l) / len(l)

In [22]:
f = open(f"{folderpath}/log/ViT.txt", 'w')

In [23]:
import time

start = time.time()
error = nn.MSELoss()
vit_model.train()
count = 0
globalMinLoss = float('inf')
print("-------Running-------")
for epoch in range(num_epochs):
    train_loss_list, valid_loss_list = [], []

    for (images, labels) in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        vit_optimizer.zero_grad()
        outputs = vit_model(images)
        outputs = torch.max(outputs, 1)[1]
 
        outputs = outputs.to(torch.float32)
        labels = labels.to(torch.float32)
        # print(outputs, labels)
        loss = error(outputs, labels)
        loss.requires_grad = True
        loss.backward()
        vit_optimizer.step()
        train_loss_list.append(loss)
        count += 1
    vit_scheduler.step()
    if epoch % 1 == 0:
        with torch.no_grad():
            for (images ,labels) in valid_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = vit_model(images.view(batch_size, 3, 224, 224))
                outputs = torch.max(outputs, 1)[1]
                outputs = outputs.to(torch.float32)
                labels = labels.to(torch.float32)
                loss = error(outputs, labels)
                valid_loss_list.append(loss)

            train_loss = getAverage(train_loss_list)
            valid_loss = getAverage(valid_loss_list)

            epoch_log = f"""{time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())} || [{epoch}/{num_epochs}], train_loss = {train_loss:.4f}, valid_loss = {valid_loss:.4f}"""
            print(epoch_log)
            f.write(epoch_log)
            f.write('\n')
            if valid_loss < globalMinLoss:
                globalMinLoss = valid_loss
                best_model_state = vit_model.state_dict()
                best_optim_state = vit_optimizer.state_dict()
                best_scheduler_state = vit_scheduler.state_dict()
elapsed = time.time() - start
print(f"End of training, elapsed time : {elapsed // 60} min {elapsed % 60} sec.")
     

-------Running-------
2023-02-05 05:15:49 || [0/1000], train_loss = 21.5503, valid_loss = 22.3828
2023-02-05 05:16:41 || [1/1000], train_loss = 21.7785, valid_loss = 22.3802
2023-02-05 05:17:16 || [2/1000], train_loss = 21.4334, valid_loss = 22.3802
2023-02-05 05:17:49 || [3/1000], train_loss = 21.5476, valid_loss = 22.2240
2023-02-05 05:18:23 || [4/1000], train_loss = 21.5802, valid_loss = 22.4531
2023-02-05 05:18:56 || [5/1000], train_loss = 21.4029, valid_loss = 22.2630
2023-02-05 05:19:29 || [6/1000], train_loss = 21.7092, valid_loss = 22.5052
2023-02-05 05:20:03 || [7/1000], train_loss = 21.5897, valid_loss = 22.4141
2023-02-05 05:20:36 || [8/1000], train_loss = 21.3764, valid_loss = 22.2865
2023-02-05 05:21:09 || [9/1000], train_loss = 21.6590, valid_loss = 22.4141
2023-02-05 05:21:42 || [10/1000], train_loss = 21.5849, valid_loss = 22.2240
2023-02-05 05:22:16 || [11/1000], train_loss = 21.6053, valid_loss = 22.3516
2023-02-05 05:22:49 || [12/1000], train_loss = 21.3947, valid_lo

KeyboardInterrupt: ignored

In [24]:
modelpath = f"{folderpath}/model/ViT"

f"{modelpath}/model_state_dict.pt"

if best_model_state is not None and best_optim_state is not None:
    torch.save(best_model_state, f"{modelpath}/model_state_dict.pt")
    torch.save(best_optim_state, f"{modelpath}/optim_state_dict.pt")
    torch.save(best_scheduler_state, f"{modelpath}/scheduler_state_dict.pt")
    print("Successfully saved.")

Successfully saved.


In [25]:
f.close()