In [1]:
from collections import namedtuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda import device
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import pandas as pd
import os
from util.train_util import train, evaluate
from PIL import Image
# from transformers import ViTForImageClassification
from torchinfo import summary
from model.vision_transformer import VisionTransformer
from torch.cuda import is_available
from util.train_util import train_and_evaluate

In [2]:
class CustomDataset(Dataset):
    """Image dataset described by a CSV annotation file.

    Column 0 of the CSV holds the image identifier (the file name without
    its ``.jpg`` extension); every remaining column is read as a numeric
    label value for that image.

    Args:
        csv_file: Path of the annotation CSV, read with ``pandas.read_csv``.
        root_dir: Directory that contains the ``<identifier>.jpg`` files.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of annotated samples (rows of the CSV)."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Positional row index into the annotation table.

        Returns:
            A ``(image, y_label)`` pair. ``image`` is the RGB PIL image, or
            whatever ``transform`` returns for it when a transform is set.
            ``y_label`` is a float64 tensor built from columns 1.. of the row.
        """
        # Format the identifier instead of concatenating so that ids which
        # pandas parsed as numbers (e.g. 7 -> "7.jpg") work as well as strings.
        img_name = os.path.join(self.root_dir, f"{self.annotations.iloc[idx, 0]}.jpg")

        # Force three channels: grayscale / palette / RGBA files would otherwise
        # reach a 3-channel Normalize with the wrong channel count. convert()
        # returns a fully loaded copy, so the file can be closed right away.
        with Image.open(img_name) as img:
            image = img.convert("RGB")

        # NOTE(review): astype(float) yields float64 labels; confirm the
        # training loop casts them to the model's dtype if it needs to.
        y_label = torch.tensor(self.annotations.iloc[idx, 1:].values.astype(float))

        if self.transform:
            image = self.transform(image)

        return image, y_label

In [3]:
# Data-pipeline settings, named so they can be tuned in one place.
# NOTE(review): the model cell hard-codes the same 16 / 224 values in its
# input shape; keep the two in sync when changing these.
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 16
TRAIN_CSV = 'train.csv'
IMAGE_DIR = 'images'
# Per-channel normalisation statistics (the standard ImageNet mean / std).
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Resize every image to a fixed size, convert it to a tensor, then normalise
# each channel with the statistics above.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

train_set = CustomDataset(csv_file=TRAIN_CSV, root_dir=IMAGE_DIR, transform=transform)
# drop_last=True discards the final incomplete batch, so every batch has
# exactly BATCH_SIZE samples.
train_loader = DataLoader(
    dataset=train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

In [4]:
# Earlier configuration, kept for reference:
# model = VisionTransformer(trains_shape, patch_size, 4, 10, 768, 12, 3072, 0.1, device='cuda:0')

# Use the first CUDA device when one is available, otherwise fall back to CPU.
# (This string rebinds the name `device` imported from torch.cuda above.)
device = "cuda:0" if is_available() else "cpu"

# The first two positional arguments correspond to `trains_shape` and
# `patch_size` in the earlier call; the shape is presumably
# (batch, channels, height, width) — TODO confirm against VisionTransformer.
model = VisionTransformer(
    (16, 3, 224, 224),
    16,
    12,
    4,
    768,
    12,
    3072,
    0.1,
    device=device,
)

# Multi-GPU wrapping is currently disabled:
# if device == 'cuda:0':
    # if device_count() > 1:
    #     model = DataParallel(model)

model.to(device)

# NOTE(review): the summary printed below lists only PatchEmbedding and two
# Linear layers (~2.5M parameters) and no encoder blocks — verify that the
# encoder layers are registered as submodules of the model.
summary(model)

Layer (type:depth-idx)                   Param #
VisionTransformer                        --
├─PatchEmbedding: 1-1                    163,584
│    └─Flatten: 2-1                      --
├─Linear: 1-2                            2,362,368
├─Linear: 1-3                            12,292
Total params: 2,538,244
Trainable params: 2,538,244
Non-trainable params: 0

In [None]:
# Training setup: number of epochs, cross-entropy objective, and Adam over
# the parameters exposed by model.parameters().
num_epochs = 1_000
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005)

# The third argument is None — presumably the validation-loader slot, i.e.
# no evaluation data is supplied; confirm against util.train_util.
train_and_evaluate(
    model,
    train_loader,
    None,
    criterion,
    optimizer,
    num_epochs,
    device,
)