In [1]:
import os
import json
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(42)
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn.functional as F
from torchvision import datasets
from torchvision.transforms import v2 as transforms
from sklearn.metrics import accuracy_score,mean_absolute_error
from tqdm import tqdm
from torchvision import datasets, models
from PIL import Image
datasetPath = ".."
anotationsPath = "..\\dataset\\anotations"


# Load the annotations
with open(os.path.join(anotationsPath, "annotations.json")) as f:
    annotations = json.load(f)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Create labels


# Create Dataset

In [2]:
class CustomDataset(Dataset):
    def __init__(self, datasetPath, imageIds,annotations,ImgPaths, transform=None):
        self.datasetPath = datasetPath
        self.imageIds = imageIds
        self.labels = {}
        self.imagePaths = {annotation['id']: annotation['path'] for annotation in ImgPaths}
        self.transform = transform
        self.load_labels(annotations)
    def __len__(self):
        return len(self.imageIds)
    def load_labels(self,annotations):
        for annotation in annotations:
            label = annotation['image_id']
            if label not in self.imageIds:
                continue
            if label not in self.labels:
                self.labels[label] = 0
            else:
                self.labels[label] += 1
    def __getitem__(self, idx):
        #print(self.imageIds[idx])
        image_id = self.imageIds[idx]
        image_path = os.path.join(self.datasetPath, self.imagePaths[image_id])
        img = Image.open(image_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        else:
            img = transforms.ToTensor()(img)
        label = torch.tensor(self.labels[image_id], dtype=torch.float32)
        return img, label
#print(annotations['annotations']["pieces"])



# Training

In [3]:
def epoch_iter(dataloader, model, loss_fn, optimizer=None, is_train=True):
    if is_train:
      assert optimizer is not None, "When training, please provide an optimizer."
      
    num_batches = len(dataloader)

    if is_train:
      model.train() # put model in train mode
    else:
      model.eval()

    total_loss = 0.0
    preds = []
    labels = []

    with torch.set_grad_enabled(is_train):
      for batch, (X, y) in enumerate(tqdm(dataloader)):
          X, y = X.to(device), y.to(device)

          # Compute prediction error
          pred = model(X)
          #print(pred)
          #print(y)
          loss = loss_fn(pred, y)

          if is_train:
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

          # Save training metrics
          total_loss += loss.item() # IMPORTANT: call .item() to obtain the value of the loss WITHOUT the computational graph attached

          probs = F.softmax(pred, dim=1)
          final_pred = torch.argmax(probs, dim=1)
          preds.extend(final_pred.cpu().numpy())
          labels.extend(y.cpu().numpy())

    return total_loss / num_batches, mean_absolute_error(labels, preds)

Epoch

In [4]:
def train(model, model_name, num_epochs, train_dataloader, validation_dataloader, loss_fn, optimizer):
  train_history = {'loss': [], 'accuracy': []}
  val_history = {'loss': [], 'accuracy': []}
  best_val_loss = np.inf
  print("Start training...")
  for t in range(num_epochs):
      print(f"\nEpoch {t+1}")
      train_loss, train_acc = epoch_iter(train_dataloader, model, loss_fn, optimizer)
      print(f"Train loss: {train_loss:.3f} \t Train acc: {train_acc:.3f}")
      val_loss, val_acc = epoch_iter(validation_dataloader, model, loss_fn, is_train=False)
      print(f"Val loss: {val_loss:.3f} \t Val acc: {val_acc:.3f}")

      # save model when val loss improves
      if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_dict = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': t}
        torch.save(save_dict, model_name + '_best_model.pth')

      # save latest model
      save_dict = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': t}
      torch.save(save_dict, model_name + '_latest_model.pth')

      # save training history for plotting purposes
      train_history["loss"].append(train_loss)
      train_history["accuracy"].append(train_acc)

      val_history["loss"].append(val_loss)
      val_history["accuracy"].append(val_acc)
      
  print("Finished")
  return train_history, val_history

In [5]:

# Load ResNet model from torchvision (with pretrained=True)
weights = models.ResNet50_Weights.IMAGENET1K_V2
resnet = models.resnet50(weights=weights)
# Freeze all layers except the last one
for param in resnet.parameters():
    param.requires_grad = False
transforms = weights.transforms()
# Change the number of neurons in the last layer to the number of classes of the CIFAR10 dataset
# TODO
num_classes = 1
resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)


resnet = resnet.to(device)


# Define the model

In [6]:








testDataset = CustomDataset(datasetPath, annotations['splits']['chessred2k']['test']['image_ids'],annotations['annotations']["pieces"],annotations['images'],transforms)
testLoader = DataLoader(testDataset, batch_size=10, shuffle=False, num_workers=0)

trainDataset = CustomDataset(datasetPath, annotations['splits']['chessred2k']['train']['image_ids'],annotations['annotations']["pieces"],annotations['images'],transforms)
trainLoader = DataLoader(trainDataset, batch_size=10, shuffle=False, num_workers=0)
validationDataset = CustomDataset(datasetPath, annotations['splits']['chessred2k']['val']['image_ids'],annotations['annotations']["pieces"],annotations['images'],transforms)
validationLoader = DataLoader(validationDataset, batch_size=10, shuffle=False, num_workers=0)


In [7]:
train(resnet, "resnet50", 10, trainLoader, validationLoader, nn.MSELoss(), torch.optim.Adam(resnet.parameters(), lr=1e-3))

Start training...

Epoch 1


  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 145/145 [02:08<00:00,  1.12it/s]


Train loss: 217.232 	 Train acc: 20.306


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 33/33 [00:29<00:00,  1.12it/s]


Val loss: 103.839 	 Val acc: 17.582

Epoch 2


  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 145/145 [02:09<00:00,  1.12it/s]


Train loss: 87.468 	 Train acc: 20.306


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 33/33 [00:29<00:00,  1.13it/s]


Val loss: 91.012 	 Val acc: 17.582

Epoch 3


  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 145/145 [02:07<00:00,  1.14it/s]


Train loss: 82.468 	 Train acc: 20.306


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 33/33 [00:29<00:00,  1.12it/s]


Val loss: 88.850 	 Val acc: 17.582

Epoch 4


  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 145/145 [02:07<00:00,  1.14it/s]


Train loss: 79.906 	 Train acc: 20.306


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 33/33 [00:29<00:00,  1.13it/s]


Val loss: 87.029 	 Val acc: 17.582

Epoch 5


  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 145/145 [02:04<00:00,  1.16it/s]


Train loss: 77.564 	 Train acc: 20.306


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 33/33 [00:27<00:00,  1.21it/s]


Val loss: 85.515 	 Val acc: 17.582

Epoch 6


  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 145/145 [02:02<00:00,  1.18it/s]


Train loss: 75.465 	 Train acc: 20.306


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 33/33 [00:28<00:00,  1.16it/s]


Val loss: 84.251 	 Val acc: 17.582

Epoch 7


  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 145/145 [02:03<00:00,  1.17it/s]


Train loss: 73.559 	 Train acc: 20.306


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 33/33 [00:28<00:00,  1.16it/s]


Val loss: 83.179 	 Val acc: 17.582

Epoch 8


  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 145/145 [02:03<00:00,  1.18it/s]


Train loss: 71.804 	 Train acc: 20.306


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 33/33 [00:28<00:00,  1.14it/s]


Val loss: 82.254 	 Val acc: 17.582

Epoch 9


  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 145/145 [02:07<00:00,  1.14it/s]


Train loss: 70.170 	 Train acc: 20.306


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 33/33 [00:29<00:00,  1.13it/s]


Val loss: 81.442 	 Val acc: 17.582

Epoch 10


  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 145/145 [02:10<00:00,  1.11it/s]


Train loss: 68.635 	 Train acc: 20.306


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 33/33 [00:29<00:00,  1.12it/s]


Val loss: 80.717 	 Val acc: 17.582
Finished


({'loss': [217.23223784792012,
   87.46793982078289,
   82.46784369698887,
   79.90599397133137,
   77.56404205190724,
   75.46497654421576,
   73.55933351845577,
   71.80443630218505,
   70.16990786585315,
   68.63495263066785],
  'accuracy': [20.305825242718445,
   20.305825242718445,
   20.305825242718445,
   20.305825242718445,
   20.305825242718445,
   20.305825242718445,
   20.305825242718445,
   20.305825242718445,
   20.305825242718445,
   20.305825242718445]},
 {'loss': [103.83909573699489,
   91.01164612625584,
   88.84998522382794,
   87.02921091426502,
   85.51486850507331,
   84.25051385706121,
   83.17861334482829,
   82.2537136222377,
   81.44172523960923,
   80.71741763028231],
  'accuracy': [17.581818181818182,
   17.581818181818182,
   17.581818181818182,
   17.581818181818182,
   17.581818181818182,
   17.581818181818182,
   17.581818181818182,
   17.581818181818182,
   17.581818181818182,
   17.581818181818182]})