In [1]:
import os
import json
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(42)
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn.functional as F
from torchvision import datasets
from torchvision.transforms import v2 as transforms
from sklearn.metrics import accuracy_score,mean_absolute_error
from tqdm import tqdm
from torchvision import datasets, models
from PIL import Image
datasetPath = ".."
anotationsPath = "..\\dataset\\anotations"


# Load the annotations
with open(os.path.join(anotationsPath, "annotations.json")) as f:
    annotations = json.load(f)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Create labels


# Create Dataset

In [2]:
class CustomDataset(Dataset):
    def __init__(self, datasetPath, imageIds,annotations,ImgPaths, transform=None):
        self.datasetPath = datasetPath
        self.imageIds = imageIds
        self.labels = {}
        self.imagePaths = {annotation['id']: annotation['path'] for annotation in ImgPaths}
        self.transform = transform
        self.load_labels(annotations)
        self.images = {}
        for image_id in self.imageIds:
            image_path = os.path.join(self.datasetPath, self.imagePaths[image_id])
            img = Image.open(image_path).convert("RGB")
            if self.transform:
                img = self.transform(img)
            else:
                img = transforms.ToTensor()(img)
            self.images[image_id] = img
    def __len__(self):
        return len(self.imageIds)
    def load_labels(self,annotations):
        for annotation in annotations:
            label = annotation['image_id']
            if label not in self.imageIds:
                continue
            if label not in self.labels:
                self.labels[label] = 0
            else:
                self.labels[label] += 1
    def __getitem__(self, idx):
        #print(self.imageIds[idx])
        image_id =self.imageIds[idx]
        img =self.images[image_id]

        label = torch.tensor(self.labels[image_id], dtype=torch.float32)
        return img, label
#print(annotations['annotations']["pieces"])



# Training

In [3]:
def epoch_iter(dataloader, model, loss_fn, optimizer=None, is_train=True):
    if is_train:
      assert optimizer is not None, "When training, please provide an optimizer."
      
    num_batches = len(dataloader)

    if is_train:
      model.train() # put model in train mode
    else:
      model.eval()

    total_loss = 0.0
    preds = []
    labels = []

    with torch.set_grad_enabled(is_train):
      for batch, (X, y) in enumerate(tqdm(dataloader)):
          X, y = X.to(device), y.to(device)

          # Compute prediction error
          pred = model(X)
          pred = pred.squeeze()
          #print(pred.shape)
          #print(y.shape)
          loss = loss_fn(pred, y)

          if is_train:
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

          # Save training metrics
          total_loss += loss.item() # IMPORTANT: call .item() to obtain the value of the loss WITHOUT the computational graph attached

          preds.extend(pred.cpu().detach().numpy())
          labels.extend(y.cpu().numpy())

    return total_loss / num_batches, mean_absolute_error(labels, preds)

Epoch

In [4]:
def train(model, model_name, num_epochs, train_dataloader, validation_dataloader, loss_fn, optimizer):
  train_history = {'loss': [], 'accuracy': []}
  val_history = {'loss': [], 'accuracy': []}
  best_val_loss = np.inf
  print("Start training...")
  for t in range(num_epochs):
      print(f"\nEpoch {t+1}")
      train_loss, train_acc = epoch_iter(train_dataloader, model, loss_fn, optimizer)
      print(f"Train loss: {train_loss:.3f} \t Train acc: {train_acc:.3f}")
      val_loss, val_acc = epoch_iter(validation_dataloader, model, loss_fn, is_train=False)
      print(f"Val loss: {val_loss:.3f} \t Val acc: {val_acc:.3f}")

      # save model when val loss improves
      if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_dict = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': t}
        torch.save(save_dict, model_name + '_best_model.pth')

      # save latest model
      save_dict = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': t}
      torch.save(save_dict, model_name + '_latest_model.pth')

      # save training history for plotting purposes
      train_history["loss"].append(train_loss)
      train_history["accuracy"].append(train_acc)

      val_history["loss"].append(val_loss)
      val_history["accuracy"].append(val_acc)
      
  print("Finished")
  return train_history, val_history

In [5]:

# Load ResNet model from torchvision (with pretrained=True)
weights = models.ResNet50_Weights.IMAGENET1K_V2
resnet = models.resnet50(weights=weights)
# Freeze all layers except the last one
#for param in resnet.parameters():
#    param.requires_grad = False
transforms = weights.transforms()
# Change the number of neurons in the last layer to the number of classes of the CIFAR10 dataset
# TODO
num_classes = 1
resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)


resnet = resnet.to(device)


# Define the model

In [6]:








testDataset = CustomDataset(datasetPath, annotations['splits']['chessred2k']['test']['image_ids'],annotations['annotations']["pieces"],annotations['images'],transforms)
testLoader = DataLoader(testDataset, batch_size=10, shuffle=False, num_workers=0)

trainDataset = CustomDataset(datasetPath, annotations['splits']['chessred2k']['train']['image_ids'],annotations['annotations']["pieces"],annotations['images'],transforms)
trainLoader = DataLoader(trainDataset, batch_size=10, shuffle=False, num_workers=0)
validationDataset = CustomDataset(datasetPath, annotations['splits']['chessred2k']['val']['image_ids'],annotations['annotations']["pieces"],annotations['images'],transforms)
validationLoader = DataLoader(validationDataset, batch_size=10, shuffle=False, num_workers=0)


In [8]:
train(resnet, "resnet50", 50, trainLoader, validationLoader, nn.MSELoss(), torch.optim.Adam(resnet.parameters(), lr=5e-5))

Start training...

Epoch 1


100%|██████████| 145/145 [00:11<00:00, 12.95it/s]


Train loss: 6.651 	 Train acc: 1.979


100%|██████████| 33/33 [00:00<00:00, 40.94it/s]


Val loss: 17.919 	 Val acc: 3.725

Epoch 2


100%|██████████| 145/145 [00:11<00:00, 13.10it/s]


Train loss: 1.891 	 Train acc: 1.083


100%|██████████| 33/33 [00:00<00:00, 39.77it/s]


Val loss: 18.494 	 Val acc: 3.783

Epoch 3


100%|██████████| 145/145 [00:11<00:00, 12.99it/s]


Train loss: 1.098 	 Train acc: 0.818


100%|██████████| 33/33 [00:00<00:00, 37.93it/s]


Val loss: 18.072 	 Val acc: 3.732

Epoch 4


100%|██████████| 145/145 [00:11<00:00, 13.17it/s]


Train loss: 0.650 	 Train acc: 0.621


100%|██████████| 33/33 [00:00<00:00, 40.75it/s]


Val loss: 17.708 	 Val acc: 3.682

Epoch 5


100%|██████████| 145/145 [00:11<00:00, 13.17it/s]


Train loss: 0.384 	 Train acc: 0.470


100%|██████████| 33/33 [00:00<00:00, 40.09it/s]


Val loss: 17.325 	 Val acc: 3.632

Epoch 6


100%|██████████| 145/145 [00:11<00:00, 13.04it/s]


Train loss: 0.225 	 Train acc: 0.354


100%|██████████| 33/33 [00:00<00:00, 39.74it/s]


Val loss: 17.051 	 Val acc: 3.593

Epoch 7


100%|██████████| 145/145 [00:10<00:00, 13.33it/s]


Train loss: 0.131 	 Train acc: 0.268


100%|██████████| 33/33 [00:00<00:00, 39.84it/s]


Val loss: 16.749 	 Val acc: 3.550

Epoch 8


100%|██████████| 145/145 [00:11<00:00, 13.16it/s]


Train loss: 0.089 	 Train acc: 0.219


100%|██████████| 33/33 [00:00<00:00, 40.56it/s]


Val loss: 16.620 	 Val acc: 3.530

Epoch 9


100%|██████████| 145/145 [00:10<00:00, 13.19it/s]


Train loss: 0.090 	 Train acc: 0.203


100%|██████████| 33/33 [00:00<00:00, 42.53it/s]


Val loss: 16.091 	 Val acc: 3.470

Epoch 10


100%|██████████| 145/145 [00:10<00:00, 13.97it/s]


Train loss: 0.126 	 Train acc: 0.232


100%|██████████| 33/33 [00:00<00:00, 43.45it/s]


Val loss: 16.193 	 Val acc: 3.481

Epoch 11


100%|██████████| 145/145 [00:10<00:00, 14.00it/s]


Train loss: 0.222 	 Train acc: 0.319


100%|██████████| 33/33 [00:00<00:00, 39.86it/s]


Val loss: 16.290 	 Val acc: 3.531

Epoch 12


100%|██████████| 145/145 [00:10<00:00, 13.76it/s]


Train loss: 0.236 	 Train acc: 0.375


100%|██████████| 33/33 [00:00<00:00, 40.63it/s]


Val loss: 15.405 	 Val acc: 3.427

Epoch 13


100%|██████████| 145/145 [00:10<00:00, 13.81it/s]


Train loss: 0.611 	 Train acc: 0.630


100%|██████████| 33/33 [00:00<00:00, 40.01it/s]


Val loss: 17.447 	 Val acc: 3.642

Epoch 14


100%|██████████| 145/145 [00:10<00:00, 13.96it/s]


Train loss: 1.207 	 Train acc: 0.888


100%|██████████| 33/33 [00:00<00:00, 42.86it/s]


Val loss: 16.821 	 Val acc: 3.594

Epoch 15


100%|██████████| 145/145 [00:10<00:00, 13.62it/s]


Train loss: 1.560 	 Train acc: 1.004


100%|██████████| 33/33 [00:00<00:00, 42.46it/s]


Val loss: 15.470 	 Val acc: 3.459

Epoch 16


100%|██████████| 145/145 [00:10<00:00, 13.88it/s]


Train loss: 1.204 	 Train acc: 0.855


100%|██████████| 33/33 [00:00<00:00, 43.46it/s]


Val loss: 14.936 	 Val acc: 3.364

Epoch 17


100%|██████████| 145/145 [00:10<00:00, 14.13it/s]


Train loss: 1.407 	 Train acc: 0.962


100%|██████████| 33/33 [00:00<00:00, 42.00it/s]


Val loss: 14.986 	 Val acc: 3.366

Epoch 18


100%|██████████| 145/145 [00:10<00:00, 13.89it/s]


Train loss: 1.296 	 Train acc: 0.937


100%|██████████| 33/33 [00:00<00:00, 41.87it/s]


Val loss: 16.264 	 Val acc: 3.461

Epoch 19


100%|██████████| 145/145 [00:10<00:00, 13.98it/s]


Train loss: 0.937 	 Train acc: 0.799


100%|██████████| 33/33 [00:00<00:00, 39.98it/s]


Val loss: 16.488 	 Val acc: 3.455

Epoch 20


100%|██████████| 145/145 [00:10<00:00, 13.71it/s]


Train loss: 0.701 	 Train acc: 0.698


100%|██████████| 33/33 [00:00<00:00, 39.18it/s]


Val loss: 14.962 	 Val acc: 3.307

Epoch 21


100%|██████████| 145/145 [00:10<00:00, 13.81it/s]


Train loss: 0.644 	 Train acc: 0.660


100%|██████████| 33/33 [00:00<00:00, 42.83it/s]


Val loss: 14.694 	 Val acc: 3.352

Epoch 22


100%|██████████| 145/145 [00:10<00:00, 14.49it/s]


Train loss: 0.579 	 Train acc: 0.615


100%|██████████| 33/33 [00:00<00:00, 43.16it/s]


Val loss: 15.745 	 Val acc: 3.486

Epoch 23


100%|██████████| 145/145 [00:10<00:00, 14.18it/s]


Train loss: 0.620 	 Train acc: 0.659


100%|██████████| 33/33 [00:00<00:00, 44.52it/s]


Val loss: 15.989 	 Val acc: 3.520

Epoch 24


100%|██████████| 145/145 [00:10<00:00, 14.24it/s]


Train loss: 0.683 	 Train acc: 0.692


100%|██████████| 33/33 [00:00<00:00, 44.13it/s]


Val loss: 15.264 	 Val acc: 3.448

Epoch 25


100%|██████████| 145/145 [00:10<00:00, 14.21it/s]


Train loss: 0.749 	 Train acc: 0.711


100%|██████████| 33/33 [00:00<00:00, 42.18it/s]


Val loss: 14.329 	 Val acc: 3.277

Epoch 26


100%|██████████| 145/145 [00:10<00:00, 14.32it/s]


Train loss: 0.698 	 Train acc: 0.676


100%|██████████| 33/33 [00:00<00:00, 41.22it/s]


Val loss: 14.980 	 Val acc: 3.250

Epoch 27


100%|██████████| 145/145 [00:11<00:00, 12.99it/s]


Train loss: 0.553 	 Train acc: 0.618


100%|██████████| 33/33 [00:00<00:00, 38.63it/s]


Val loss: 15.411 	 Val acc: 3.287

Epoch 28


100%|██████████| 145/145 [00:10<00:00, 13.94it/s]


Train loss: 0.441 	 Train acc: 0.554


100%|██████████| 33/33 [00:00<00:00, 41.37it/s]


Val loss: 14.502 	 Val acc: 3.250

Epoch 29


100%|██████████| 145/145 [00:10<00:00, 14.18it/s]


Train loss: 0.407 	 Train acc: 0.517


100%|██████████| 33/33 [00:00<00:00, 44.38it/s]


Val loss: 14.533 	 Val acc: 3.343

Epoch 30


100%|██████████| 145/145 [00:10<00:00, 13.46it/s]


Train loss: 0.391 	 Train acc: 0.510


100%|██████████| 33/33 [00:00<00:00, 40.02it/s]


Val loss: 15.059 	 Val acc: 3.416

Epoch 31


100%|██████████| 145/145 [00:10<00:00, 13.23it/s]


Train loss: 0.477 	 Train acc: 0.569


100%|██████████| 33/33 [00:00<00:00, 39.52it/s]


Val loss: 14.485 	 Val acc: 3.361

Epoch 32


100%|██████████| 145/145 [00:10<00:00, 13.61it/s]


Train loss: 0.571 	 Train acc: 0.613


100%|██████████| 33/33 [00:00<00:00, 40.44it/s]


Val loss: 13.161 	 Val acc: 3.095

Epoch 33


100%|██████████| 145/145 [00:11<00:00, 13.09it/s]


Train loss: 0.509 	 Train acc: 0.573


100%|██████████| 33/33 [00:00<00:00, 38.70it/s]


Val loss: 14.651 	 Val acc: 3.203

Epoch 34


100%|██████████| 145/145 [00:10<00:00, 13.42it/s]


Train loss: 0.450 	 Train acc: 0.551


100%|██████████| 33/33 [00:00<00:00, 42.90it/s]


Val loss: 15.818 	 Val acc: 3.345

Epoch 35


100%|██████████| 145/145 [00:10<00:00, 13.86it/s]


Train loss: 0.443 	 Train acc: 0.543


100%|██████████| 33/33 [00:00<00:00, 43.25it/s]


Val loss: 15.293 	 Val acc: 3.421

Epoch 36


100%|██████████| 145/145 [00:10<00:00, 13.61it/s]


Train loss: 0.366 	 Train acc: 0.486


100%|██████████| 33/33 [00:00<00:00, 42.39it/s]


Val loss: 15.377 	 Val acc: 3.435

Epoch 37


100%|██████████| 145/145 [00:10<00:00, 13.87it/s]


Train loss: 0.427 	 Train acc: 0.523


100%|██████████| 33/33 [00:00<00:00, 39.18it/s]


Val loss: 14.352 	 Val acc: 3.331

Epoch 38


100%|██████████| 145/145 [00:10<00:00, 13.75it/s]


Train loss: 0.513 	 Train acc: 0.575


100%|██████████| 33/33 [00:00<00:00, 39.96it/s]


Val loss: 13.150 	 Val acc: 3.110

Epoch 39


100%|██████████| 145/145 [00:10<00:00, 13.62it/s]


Train loss: 0.427 	 Train acc: 0.512


100%|██████████| 33/33 [00:00<00:00, 40.48it/s]


Val loss: 15.773 	 Val acc: 3.403

Epoch 40


100%|██████████| 145/145 [00:10<00:00, 13.45it/s]


Train loss: 0.382 	 Train acc: 0.505


100%|██████████| 33/33 [00:00<00:00, 41.88it/s]


Val loss: 16.373 	 Val acc: 3.504

Epoch 41


100%|██████████| 145/145 [00:10<00:00, 13.37it/s]


Train loss: 0.413 	 Train acc: 0.513


100%|██████████| 33/33 [00:00<00:00, 41.19it/s]


Val loss: 17.001 	 Val acc: 3.600

Epoch 42


100%|██████████| 145/145 [00:10<00:00, 13.64it/s]


Train loss: 0.394 	 Train acc: 0.501


100%|██████████| 33/33 [00:00<00:00, 40.36it/s]


Val loss: 15.082 	 Val acc: 3.415

Epoch 43


100%|██████████| 145/145 [00:10<00:00, 13.53it/s]


Train loss: 0.484 	 Train acc: 0.543


100%|██████████| 33/33 [00:00<00:00, 39.88it/s]


Val loss: 14.418 	 Val acc: 3.278

Epoch 44


100%|██████████| 145/145 [00:10<00:00, 13.48it/s]


Train loss: 0.465 	 Train acc: 0.522


100%|██████████| 33/33 [00:00<00:00, 39.37it/s]


Val loss: 15.312 	 Val acc: 3.413

Epoch 45


100%|██████████| 145/145 [00:10<00:00, 13.61it/s]


Train loss: 0.370 	 Train acc: 0.462


100%|██████████| 33/33 [00:00<00:00, 40.47it/s]


Val loss: 16.789 	 Val acc: 3.614

Epoch 46


100%|██████████| 145/145 [00:10<00:00, 13.44it/s]


Train loss: 0.389 	 Train acc: 0.495


100%|██████████| 33/33 [00:00<00:00, 41.87it/s]


Val loss: 15.911 	 Val acc: 3.492

Epoch 47


100%|██████████| 145/145 [00:10<00:00, 13.49it/s]


Train loss: 0.405 	 Train acc: 0.519


100%|██████████| 33/33 [00:00<00:00, 42.54it/s]


Val loss: 14.077 	 Val acc: 3.283

Epoch 48


100%|██████████| 145/145 [00:10<00:00, 13.52it/s]


Train loss: 0.450 	 Train acc: 0.531


100%|██████████| 33/33 [00:00<00:00, 40.71it/s]


Val loss: 15.438 	 Val acc: 3.402

Epoch 49


100%|██████████| 145/145 [00:10<00:00, 13.59it/s]


Train loss: 0.398 	 Train acc: 0.505


100%|██████████| 33/33 [00:00<00:00, 37.06it/s]


Val loss: 18.064 	 Val acc: 3.679

Epoch 50


100%|██████████| 145/145 [00:10<00:00, 13.58it/s]


Train loss: 0.506 	 Train acc: 0.561


100%|██████████| 33/33 [00:00<00:00, 41.55it/s]


Val loss: 16.460 	 Val acc: 3.571
Finished


({'loss': [6.6507436414217125,
   1.8913418828413404,
   1.0983278893608728,
   0.6498562832844669,
   0.3840759050614875,
   0.22486477206493247,
   0.13057481836656043,
   0.0893536129163514,
   0.08954448015910799,
   0.12575054579125397,
   0.22231028902761896,
   0.23606640462731493,
   0.611378651857376,
   1.2066227317370217,
   1.560108596163577,
   1.2043701273613963,
   1.4072996221482754,
   1.296275490727918,
   0.9370329189467533,
   0.7008833112965884,
   0.6441257041589967,
   0.578833111850866,
   0.6202734758627826,
   0.6827499552565659,
   0.7489048665710565,
   0.6979218454453452,
   0.5530330035321671,
   0.4407086107109127,
   0.40679241468937233,
   0.39106888406235596,
   0.4774621525843596,
   0.5714256770523458,
   0.5089882780251832,
   0.44952424774909844,
   0.44260168142359835,
   0.3657361792400479,
   0.4265965479715117,
   0.5130489885807037,
   0.42691646223438195,
   0.3816845606370219,
   0.41298458988553494,
   0.3944283513544962,
   0.4837625167750