In [None]:
import torch
from torch import optim
from PIL import Image
from resnet18 import ResNet18
from utils import *
import json
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset, random_split
import time
import os
from torchvision import transforms

In [None]:
class chessDataset(Dataset):
    """Dataset yielding one board square per sample from a folder of board images.

    Every entry of ``folder_path`` contributes 64 consecutive samples, one per
    square in row-major order. The label of a square is read from the file
    name: the text before the first ``'.'`` is passed to ``fen_to_labels`` and
    the result is flattened to 64 entries.

    NOTE(review): the crop arithmetic assumes the transformed image is a
    ``(C, 400, 400)`` tensor (8 x 8 squares of 50 px) — confirm against the data.
    """

    # Board geometry used to map a flat sample index onto a square crop.
    BOARD_SIDE = 8
    SQUARES_PER_BOARD = BOARD_SIDE * BOARD_SIDE
    SQUARE_PIXELS = 50

    def __init__(self, folder_path, transform = None):
        """Remember the folder and transform, and snapshot the file listing.

        Args:
            folder_path: directory whose entries are the board images.
            transform: callable turning the opened image into a tensor that can
                be sliced as ``[:, rows, cols]``. When ``None``, a plain
                ``transforms.ToTensor()`` is applied at access time.
        """
        self.folder_path = folder_path
        self.transform = transform
        # List the directory once: calling os.listdir on every access is slow
        # and its order is not guaranteed to be stable. Sorting pins the
        # index -> file mapping.
        self.files = sorted(os.listdir(folder_path))

    def __len__(self):
        """Return the total number of samples (64 squares per listed file)."""
        return len(self.files) * self.SQUARES_PER_BOARD

    @classmethod
    def _locate(cls, index):
        """Split a flat index into ``(file_nb, square, square_i, square_j)``."""
        file_nb, square = divmod(index, cls.SQUARES_PER_BOARD)
        square_i, square_j = divmod(square, cls.BOARD_SIDE)
        return file_nb, square, square_i, square_j

    def __getitem__(self, index):
        """Return ``(features, label)`` for the square addressed by ``index``.

        Raises:
            IndexError: if ``index`` addresses a file beyond the listing.
        """
        # The square is the position of `index` *within* its board
        # (index % 64), not a function of the file number.
        file_nb, square, square_i, square_j = self._locate(index)
        file = self.files[file_nb]
        label = int(fen_to_labels(file.split('.')[0]).reshape(64,)[square].item())

        # Without a user transform we still need a tensor to slice.
        transform = self.transform if self.transform is not None else transforms.ToTensor()
        # The context manager releases the file handle as soon as the pixels
        # have been converted.
        with Image.open(os.path.join(self.folder_path, file)) as image:
            features = transform(image)

        side = self.SQUARE_PIXELS
        features = features[:, side*square_i:side*(square_i+1), side*square_j:side*(square_j+1)]
        return features, label

In [None]:
class Arguments:
    """Experiment configuration, stored as plain class-level attributes.

    Instantiate with ``Arguments()`` and override individual attributes on the
    instance before passing it to ``main_training``.
    """
    # Model: key into the model lookup in main_training (only 'resnet18' is registered there)
    model: str='resnet18'
    
    # Data: number of samples per batch for every DataLoader
    batch_size: int = 32
        
    # Optimization
    # Cap on the number of batches processed per train/eval pass
    epoch_max_iter: int = 1000
    optimizer: str = 'adamw'  # [sgd, momentum, adam, adamw]
    epochs: int = 1
    lr: float = 5e-5
    # Only read when optimizer == 'momentum'
    momentum: float = 0.9
    # Passed to adamw, sgd and momentum (not to adam)
    weight_decay: float = 5e-3

    # Experiment
    # Root data folder; images are read from its 'train' subfolder
    datadir: str = 'dataset'
    # Folder receiving results.json
    logdir: str = 'logs'
    # Folder receiving the saved model weights (resnet18.pth)
    modeldir: str = 'models'
    # Passed to seed_experiment at the start of main_training
    seed: int = 420

    # Miscellaneous
    # Device the model and batches are moved to
    device: str = 'cuda'
    # Log the running loss every `print_every` iterations
    print_every: int = 100

In [None]:
# Evaluation pipeline: no augmentation, only conversion to a tensor.
test_transform = transforms.Compose([transforms.ToTensor()])

# Training pipeline: blur, then colour jitter (applied 80% of the time),
# then occasional grayscale conversion, and finally conversion to a tensor.
train_transform = transforms.Compose([
    transforms.GaussianBlur(kernel_size=3),
    transforms.RandomApply(
        [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)],
        p=0.8,
    ),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
])

In [None]:
def train(epoch, model, dataloader, optimizer, args):
    """Run one training pass over at most ``args.epoch_max_iter`` batches.

    Args:
        epoch: epoch number, used only in log messages.
        model: network called as ``model(imgs)``; put in train mode here.
        dataloader: iterable yielding ``(imgs, labels)`` batches.
        optimizer: optimizer stepped once per batch.
        args: object providing ``epoch_max_iter``, ``device`` and ``print_every``.

    Returns:
        ``(mean_loss, mean_accuracy, elapsed_seconds)``. The means are taken
        over the batches actually processed; both are 0.0 if none were.
    """
    model.train()
    total_iters = 0
    loss_sum = 0.0
    acc_sum = 0.0
    start_time = time.time()
    # zip with range stops after exactly epoch_max_iter batches (the previous
    # `idx > epoch_max_iter` check let one extra batch through) and does not
    # pull a batch that would then be discarded.
    for idx, batch in zip(range(args.epoch_max_iter), dataloader):
        batch = to_device(batch, args.device)
        optimizer.zero_grad()
        imgs, labels = batch
        logits = model(imgs)
        loss = cross_entropy_loss(logits, labels)
        acc = compute_accuracy(logits, labels)

        loss.backward()
        optimizer.step()
        acc_sum += acc.item()
        loss_sum += loss.item()
        total_iters += 1

        if idx % args.print_every == 0:
            tqdm.write(f"[TRAIN] Epoch: {epoch}, Iter: {idx} out of {args.epoch_max_iter}, Loss: {loss.item():.5f}")

    # Average over the batches really seen; dividing by epoch_max_iter gave
    # wrong figures whenever the loader was shorter (or one batch longer).
    denom = max(total_iters, 1)
    epoch_loss = loss_sum / denom
    epoch_accuracy = acc_sum / denom
    tqdm.write(f"== [TRAIN] Epoch: {epoch}, Accuracy: {epoch_accuracy:.3f} ==>")
    return epoch_loss, epoch_accuracy, time.time() - start_time


def evaluate(epoch, model, dataloader, args, mode="val"):
    """Evaluate ``model`` on at most ``args.epoch_max_iter`` batches without gradients.

    Args:
        epoch: epoch number, used only in log messages.
        model: network called as ``model(imgs)``; put in eval mode here.
        dataloader: iterable yielding ``(imgs, labels)`` batches.
        args: object providing ``epoch_max_iter``, ``device`` and ``print_every``.
        mode: tag shown (upper-cased) in log messages, e.g. ``"val"`` or ``"test"``.

    Returns:
        ``(mean_loss, mean_accuracy, elapsed_seconds)``. The means are taken
        over the batches actually processed; both are 0.0 if none were.
    """
    model.eval()
    loss_sum = 0.0
    acc_sum = 0.0
    total_iters = 0
    start_time = time.time()
    with torch.no_grad():
        # zip with range enforces the cap exactly (the previous
        # `idx > epoch_max_iter` check let one extra batch through).
        for idx, batch in zip(range(args.epoch_max_iter), dataloader):
            batch = to_device(batch, args.device)
            imgs, labels = batch
            logits = model(imgs)
            loss = cross_entropy_loss(logits, labels)
            acc = compute_accuracy(logits, labels)
            acc_sum += acc.item()
            loss_sum += loss.item()
            total_iters += 1
            if idx % args.print_every == 0:
                tqdm.write(
                    f"[{mode.upper()}] Epoch: {epoch}, Iter: {idx} out of {args.epoch_max_iter}, Loss: {loss.item():.5f}"
                )

    # Average over the batches really seen rather than over epoch_max_iter.
    denom = max(total_iters, 1)
    epoch_loss = loss_sum / denom
    epoch_accuracy = acc_sum / denom
    # Report the processed-batch count: the loop variable is unbound when the
    # dataloader is empty, which used to raise NameError here.
    tqdm.write(
        f"=== [{mode.upper()}] Epoch: {epoch}, Iter: {total_iters}, Accuracy: {epoch_accuracy:.3f} ===>"
    )
    return epoch_loss, epoch_accuracy, time.time() - start_time

In [None]:
def _build_optimizer(model, args):
    """Create the optimizer named by ``args.optimizer`` for ``model``'s parameters."""
    params = model.parameters()
    if args.optimizer == "adamw":
        return optim.AdamW(params, lr=args.lr, weight_decay=args.weight_decay)
    if args.optimizer == "adam":
        return optim.Adam(params, lr=args.lr)
    if args.optimizer == "sgd":
        return optim.SGD(params, lr=args.lr, weight_decay=args.weight_decay)
    if args.optimizer == "momentum":
        return optim.SGD(
            params,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    # Previously an unknown name fell through and crashed later with an
    # unrelated unbound-variable error.
    raise ValueError(
        f"Unknown optimizer '{args.optimizer}'; expected one of: sgd, momentum, adam, adamw"
    )


def main_training(args):
    """Train, validate and test a model, then write logs and weights to disk.

    Steps: seed the experiment, build train/validation/test loaders, build the
    model and optimizer, run ``args.epochs`` train+validation passes, run one
    test pass, then write ``results.json`` to ``args.logdir`` and the model
    weights to ``args.modeldir`` (each skipped when the directory is ``None``).

    Args:
        args: configuration object (see ``Arguments``).

    Raises:
        ValueError: if ``args.optimizer`` is not a supported name.
        KeyError: if ``args.model`` is not a registered model name.
    """
    seed_experiment(args.seed)

    # Two views of the same folder: augmented for training, plain for
    # validation. Assigning `.transform` on the Subset returned by random_split
    # has no effect on the wrapped dataset, so validation used to run with the
    # training augmentations.
    train_dir = os.path.join(args.datadir, 'train')
    augmented_source = chessDataset(train_dir, train_transform)
    plain_source = chessDataset(train_dir, test_transform)
    train_set, val_split = random_split(augmented_source, [0.8, 0.2])
    val_set = torch.utils.data.Subset(plain_source, val_split.indices)

    # Prefer a held-out 'test' folder; evaluating on the training folder only
    # measures memorisation. Fall back to it (loudly) when no test folder exists.
    test_dir = os.path.join(args.datadir, 'test')
    if not os.path.isdir(test_dir):
        print(f"WARNING: '{test_dir}' not found; test metrics will be computed on '{train_dir}'.")
        test_dir = train_dir
    test_set = chessDataset(test_dir, test_transform)

    train_dataloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, pin_memory=True, num_workers=2)
    # Shuffling is kept for val/test: each pass is capped at
    # args.epoch_max_iter batches, so it yields a random sample of the split.
    valid_dataloader = DataLoader(val_set, batch_size=args.batch_size, shuffle=True, drop_last=False, num_workers=2)
    test_dataloader = DataLoader(test_set, batch_size=args.batch_size, shuffle=True, drop_last=False, num_workers=2)

    # Load model
    print(f'Build model {args.model.upper()}...')
    print('############################################')
    model_cls = {'resnet18': ResNet18}[args.model]
    model = model_cls(num_classes=13)
    model.to(args.device)

    # Optimizer
    optimizer = _build_optimizer(model, args)

    print(
        f"Initialized {args.model.upper()} model with {sum(p.numel() for p in model.parameters())} "
        f"total parameters, of which {sum(p.numel() for p in model.parameters() if p.requires_grad)} are learnable."
    )

    train_losses, valid_losses = [], []
    train_accs, valid_accs = [], []
    train_times, valid_times = [], []

    for epoch in range(args.epochs):
        tqdm.write(f"====== Epoch {epoch} ======>")
        loss, acc, wall_time = train(epoch, model, train_dataloader, optimizer, args)
        train_losses.append(loss)
        train_accs.append(acc)
        train_times.append(wall_time)

        loss, acc, wall_time = evaluate(epoch, model, valid_dataloader, args)
        valid_losses.append(loss)
        valid_accs.append(acc)
        valid_times.append(wall_time)

    # Index of the last epoch run; stays defined even when args.epochs == 0.
    last_epoch = max(args.epochs - 1, 0)
    test_loss, test_acc, test_time = evaluate(
        last_epoch, model, test_dataloader, args, mode="test"
    )
    if valid_accs:
        print(f"===== Best validation Accuracy: {max(valid_accs):.3f} =====>")

    # Save log if logdir provided
    if args.logdir is not None:
        print(f'Writing training logs to {args.logdir}...')
        os.makedirs(args.logdir, exist_ok=True)
        with open(os.path.join(args.logdir, 'results.json'), 'w') as f:
            json.dump(
                {
                    "train_losses": train_losses,
                    "valid_losses": valid_losses,
                    "train_accs": train_accs,
                    "valid_accs": valid_accs,
                    "test_loss": test_loss,
                    "test_acc": test_acc,
                    # Wall-clock timings were collected but never persisted.
                    "train_times": train_times,
                    "valid_times": valid_times,
                    "test_time": test_time,
                },
                f,
                indent=4,
            )

    # Save weights if modeldir provided. The target folder is created first
    # (torch.save does not create it), and saving no longer depends on logdir.
    if args.modeldir is not None:
        os.makedirs(args.modeldir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(args.modeldir,'resnet18.pth'))

In [None]:
# Run the full experiment with the default configuration. `args` stays bound
# at notebook level so it can be inspected or tweaked and re-run afterwards.
args = Arguments()
main_training(args)