In [None]:
import torch
from torch.nn import Linear, CrossEntropyLoss
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader
from torchvision.transforms import ToTensor, Resize, Compose, RandomGrayscale, RandomHorizontalFlip, RandomResizedCrop
from torchvision.models import resnet101
from collections import Counter
from pathlib import Path
import os
from time import time
import random

from train_and_test_classification import seed_all, train_test_classifier

In [None]:
# Presumably fixes the random seeds so runs are repeatable.
# NOTE(review): seed_all is imported from train_and_test_classification; which
# generators it seeds and with what value is not visible here -- confirm.
seed_all()

In [None]:
# Root of the project tree. Defaults to the historical hard-coded location,
# but can be redirected without editing the notebook by exporting
# POKEMON_PROJECT_ROOT before starting the kernel.
project_root = Path(os.environ.get("POKEMON_PROJECT_ROOT", "/project_root"))

# Expected layout: one sub-folder per class under the dataset root.
dataset_root = project_root/"Datasets"/"pokemon_dataset"

# Parent folder shared by every run of this experiment.
run_folder = project_root/"runs"/"resnet_pokemon_dataset"
run_folder.mkdir(exist_ok=True, parents=True)

# Each execution gets its own folder named after the current Unix timestamp
# (whole seconds). exist_ok=False makes a name clash fail loudly instead of
# silently mixing the artifacts of two runs.
current_run_folder = run_folder/f"{int(time())}"
current_run_folder.mkdir(exist_ok=False)

In [None]:
# Smoke test of the architecture: build an untrained ResNet-101, replace its
# final fully connected layer with a 149-output head (presumably the number of
# classes in the dataset -- confirm) and check the output shape for one
# 224x224 RGB input.
test_network = resnet101(pretrained=False)
test_network.fc = Linear(in_features=2048, out_features=149)
print(test_network)

X = torch.ones(1, 3, 224, 224)
with torch.no_grad():  # shape check only, no gradients needed
  y = test_network(X)
assert y.shape == (1, 149)

In [None]:
def get_file_list(root):
  """Collect the image files of a one-folder-per-class dataset directory.

  Every direct child of ``root`` must be a directory; its name becomes a class
  name. Inside each class folder, every regular file matching ``*.*`` is kept
  except SVG files (suffix compared case-insensitively), which are only counted.

  Both the class list and the file list are sorted, because directory
  iteration order is arbitrary. Sorting keeps class indices, and any seeded
  shuffle of the returned list, identical across machines and runs.

  Prints the number of skipped SVG files and the number of retained files.

  Args:
    root: ``pathlib.Path`` of the dataset folder.

  Returns:
    ``(filelist, classes)``: a list of ``Path`` objects for the retained files
    and the list of class-folder names.

  Raises:
    FileNotFoundError: if ``root`` does not exist.
    ValueError: if ``root`` directly contains anything that is not a directory.
  """
  if not root.exists():
    raise FileNotFoundError(f"Dataset folder doesn't exist. Path : {root}")

  # iterdir() yields entries in filesystem-dependent order; sort for determinism.
  subfolders = sorted(root.iterdir())
  for folder in subfolders:
    if not folder.is_dir():
      raise ValueError("Root Folder should not have files, only subfolders")
  classes = [folder.name for folder in subfolders]

  filelist = []
  svg_count = 0
  for folder in subfolders:
    for file in sorted(folder.glob("*.*")):
      # A sub-directory whose name contains a dot also matches "*.*".
      if not file.is_file():
        continue
      # Case-insensitive so that ".SVG" is skipped as well.
      if file.suffix.lower() == ".svg":
        svg_count += 1
        continue
      filelist.append(file)

  print(f"SVG Number: {svg_count}")
  print(f"Retained: {len(filelist)}")
  return filelist, classes

In [None]:
# Gather every usable image path together with the class names.
all_files, classes = get_file_list(dataset_root)

# Re-seed Python's RNG right before shuffling so the split is repeatable.
random.seed(2)
random.shuffle(all_files)

# 70% train, 15% validation, and whatever is left over (about 15%) for test.
total_files = len(all_files)
train_len = int(0.7 * total_files)
val_len = int(0.15 * total_files)
test_len = total_files - train_len - val_len

# Three contiguous, non-overlapping slices of the shuffled list.
val_end = train_len + val_len
train_files = all_files[:train_len]
val_files = all_files[train_len:val_end]
test_files = all_files[val_end:]

In [None]:
class ImageDataset(Dataset):
  """Map-style dataset over image files whose parent folder names the class.

  Args:
    filelist: sequence of ``pathlib.Path`` objects, one per image.
    classes: sequence of class names; a sample's label is the index of its
      parent folder's name in this sequence.
    transforms: callable applied to each loaded image in ``__getitem__``.

  Raises:
    ValueError: if a file's parent folder name is not in ``classes``.
  """

  def __init__(self, filelist, classes, transforms):
    self.transforms = transforms
    self.filelist = filelist

    # Build the name -> index table once, so each file costs one dict lookup
    # instead of a linear classes.index() scan. setdefault keeps the first
    # index of a duplicated name, matching list.index().
    class_to_index = {}
    for class_index, class_name in enumerate(classes):
      class_to_index.setdefault(class_name, class_index)

    labellist = []
    for file in filelist:
      class_name = os.path.basename(file.parent)
      if class_name not in class_to_index:
        raise ValueError(f"{file}: parent folder '{class_name}' is not in classes")
      labellist.append(class_to_index[class_name])
    self.labellist = labellist

  def __getitem__(self, index):
    """Return ``(transformed image, label)`` for the sample at ``index``."""
    # Images are read lazily, on access, with torchvision's default_loader.
    image = default_loader(self.filelist[index])
    return self.transforms(image), self.labellist[index]

  def __len__(self):
    """Return the number of samples."""
    return len(self.labellist)

  def get_label_list(self):
    """Return the list of integer labels, in the same order as the files."""
    return self.labellist

In [None]:
# Training pipeline: convert to a tensor, then apply random augmentation
# (random resized crop to 224x224, horizontal flip, grayscale).
train_transforms = Compose([
  ToTensor(),
  RandomResizedCrop((224,224)),
  RandomHorizontalFlip(),
  RandomGrayscale(),
])

# Validation and test pipelines: convert to a tensor and resize to 224x224,
# with no random augmentation.
val_transforms = Compose([ToTensor(), Resize((224,224))])
test_transforms = Compose([ToTensor(), Resize((224,224))])

train_data = ImageDataset(filelist=train_files, classes=classes, transforms=train_transforms)
val_data = ImageDataset(filelist=val_files, classes=classes, transforms=val_transforms)
test_data = ImageDataset(filelist=test_files, classes=classes, transforms=test_transforms)

In [None]:
# Count the samples per class index in each split.
train_counts, val_counts, test_counts = (
  Counter(split.get_label_list()) for split in (train_data, val_data, test_data)
)

# Tab-separated table with one row per class; a class that is absent from a
# split shows 0 because Counter returns 0 for missing keys.
print("Class Name\t\tTrain Count\tVal Count\tTest_Count\n")
for class_index, class_name in enumerate(classes):
  n_train = train_counts[class_index]
  n_val = val_counts[class_index]
  n_test = test_counts[class_index]
  print(f"{class_name: <20}\t\t{n_train}\t\t{n_val}\t\t{n_test}")

In [None]:
# ImageNet-pretrained ResNet-101 with its final fully connected layer replaced
# by a fresh head that has one output per dataset class.
model = resnet101(pretrained=True)
model.fc = Linear(2048, len(classes))

# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# TensorBoard event files for this run go under <run folder>/logs.
logger = SummaryWriter(current_run_folder/"logs")
try:
  # Record the network graph, traced with one dummy 224x224 RGB input.
  logger.add_graph(model, torch.ones(1, 3, 224, 224, device=device))

  # exist_ok=False: the run folder is new, so this folder must not exist yet.
  checkpoint_folder = current_run_folder/"checkpoints"
  checkpoint_folder.mkdir(exist_ok=False)

  # NOTE(review): train_test_classifier comes from train_and_test_classification;
  # presumably it trains, validates and tests the model and stops early after
  # `early_stopping_epochs` epochs without improvement -- confirm there.
  training_result = train_test_classifier(model=model,
                                          train_data=train_data,
                                          val_data=val_data,
                                          test_data=test_data,
                                          batch_size=64,
                                          num_epochs=100,
                                          loss_function=CrossEntropyLoss(),
                                          optimizer=Adam(model.parameters(), lr = 0.0003),
                                          logger=logger,
                                          device=device,
                                          checkpoint_folder=checkpoint_folder,
                                          early_stopping_epochs=5)
finally:
  # Flush and close the event file even if training fails or is interrupted.
  logger.close()

# Keep the call's return value as the cell output, as in the original cell.
training_result