In [None]:
#import libraries
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
import torch.utils.tensorboard as tb
import torch.utils.data as data

In [None]:
!pip install wandb
!wandb login

In [None]:
import wandb
# Start a Weights & Biases run in project "CNN"; train() and train_one_epoch()
# send their metrics to it via wandb.log when their `tblog` argument is truthy.
wandb.init(project="CNN")

In [None]:
class ClassificationMetrics:
  """Running confusion matrix with accuracy / mean-accuracy / mean-IoU summaries.

  ``C[t, p]`` counts samples whose target is ``t`` and prediction is ``p``
  (rows = true class, columns = predicted class).
  """

  def __init__(self, num_classes=10, device=None):
    """
    Args:
      num_classes: number of classes; ``C`` is ``num_classes x num_classes``.
      device: device holding the matrix. Defaults to CUDA when it is available
        and to CPU otherwise.
    """
    if device is None:
      device = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.num_classes = num_classes
    self.C = torch.zeros(num_classes, num_classes, device=device)

  def add(self, yp, yt):
    """Accumulate predicted labels ``yp`` against target labels ``yt``.

    Both are integer label tensors with the same number of elements; they are
    flattened and moved to the matrix's device before counting.
    """
    with torch.no_grad():
      n = self.C.shape[1]
      # Encode each (target, prediction) pair as one flat index t*n + p so a
      # single bincount fills the whole matrix.
      flat = (yt.reshape(-1) * n + yp.reshape(-1)).to(self.C.device)
      self.C += flat.bincount(minlength=self.C.numel()).view(self.C.shape).float()

  def clear(self):
    """Reset all counts to zero, keeping size and device."""
    self.C.zero_()

  @staticmethod
  def _nanmean(values):
    """Mean over the non-NaN entries (NaN = per-class score undefined, 0/0)."""
    return values[~torch.isnan(values)].mean().item()

  def acc(self):
    """Overall accuracy (correct / total) as a Python float."""
    return (self.C.diag().sum() / self.C.sum()).item()

  def mAcc(self):
    """Mean per-class recall; classes with no target samples are skipped."""
    return self._nanmean(self.C.diag() / self.C.sum(-1))

  def mIoU(self):
    """Mean per-class IoU; classes absent from both targets and predictions are skipped."""
    C = self.C
    return self._nanmean(C.diag() / (C.sum(0) + C.sum(1) - C.diag()))

  def confusion_matrix(self):
    """Return the underlying (true x predicted) count matrix (not a copy)."""
    return self.C
    
# Class index -> galaxy morphology name, built once at import time.
_GALAXY_CLASS_NAMES = {
    0: 'Barred Spiral',
    1: 'Cigar Shaped Smooth',
    2: 'Disturbed',
    3: 'Edge-on with Bulge',
    4: 'Edge-on without Bulge',
    5: 'In-between Round Smooth',
    6: 'Merging',
    7: 'Round Smooth',
    8: 'Unbarred Loose Spiral',
    9: 'Unbarred Tight Spiral',
}


def class_to_string(label):
  """Return the galaxy class name for integer index ``label`` (0-9).

  Raises KeyError for any label not in the table.
  """
  return _GALAXY_CLASS_NAMES[label]

class galaxy_folder(torchvision.datasets.ImageFolder):
  """ImageFolder variant whose items also carry the image file path.

  ``dataset[i]`` returns ``(sample, target, path)``. ``transform`` and
  ``target_transform`` are applied when they are set.
  """

  def __getitem__(self, index):
    img_path, label = self.samples[index]
    image = self.loader(img_path)
    image = image if self.transform is None else self.transform(image)
    label = label if self.target_transform is None else self.target_transform(label)
    return image, label, img_path


def evaluate(yt, yp, num_classes=10):
  """Compute accuracy, mean per-class accuracy and mean IoU for one set of labels.

  Args:
    yt: integer tensor of target labels in ``[0, num_classes)`` (any shape).
    yp: integer tensor of predicted labels with the same number of elements.
    num_classes: number of classes.

  Returns:
    dict with float entries ``'Acc'``, ``'mAcc'`` and ``'mIoU'``. Classes whose
    per-class score is undefined (0/0) are left out of the means instead of
    turning them into NaN; an empty input gives NaN accuracy.
  """
  yt = yt.reshape(-1)
  yp = yp.reshape(-1)
  # Row = true class, column = predicted class.
  C = (yt * num_classes + yp).bincount(minlength=num_classes**2).view(num_classes, num_classes).float()

  def nanmean(values):
    return values[~torch.isnan(values)].mean().item()

  total = yt.numel()
  return {
      'Acc': C.diag().sum().item() / total if total else float('nan'),
      'mAcc': nanmean(C.diag() / C.sum(-1)),
      'mIoU': nanmean(C.diag() / (C.sum(0) + C.sum(1) - C.diag())),
  }

def validate(model, metric_tracker, dataloader):
  """Run ``model`` over ``dataloader`` in eval mode, accumulating predictions.

  The tracker is cleared first, so afterwards it holds only this pass's
  counts. Batches are moved to the device of the model's parameters, so the
  same code serves CPU and GPU models. The model is left in eval mode.

  Args:
    model: classifier whose output has the class scores on the last dim.
    metric_tracker: object with ``clear()`` and ``add(yp, yt)``,
      e.g. ClassificationMetrics.
    dataloader: iterable of ``(X, yt)`` batches.
  """
  model.eval()
  metric_tracker.clear()
  first_param = next(model.parameters(), None)
  device = first_param.device if first_param is not None else torch.device('cpu')

  with torch.no_grad():
    for X, yt in dataloader:
      X = X.to(device, non_blocking=True)
      yt = yt.to(device, non_blocking=True)
      y = model(X).argmax(-1)
      metric_tracker.add(y, yt)

In [None]:
def train_one_epoch(model, loss_func, metric_tracker, dataloader, optimizer, epoch, tblog=None):
  """Train ``model`` for one pass over ``dataloader``.

  Args:
    model: classifier whose output has the class scores on the last dim. It
      is put in train mode and moved to CUDA when available (CPU otherwise).
    loss_func: criterion called as ``loss_func(logits, targets)``.
    metric_tracker: object with ``clear()`` / ``add(yp, yt)``. It is cleared at
      the start, so afterwards it holds this epoch's training predictions.
    dataloader: iterable of ``(X, yt)`` batches.
    optimizer: optimizer over ``model``'s parameters.
    epoch: current epoch number (not used here; kept for the call signature).
    tblog: when truthy, the per-batch loss is sent to ``wandb.log``.
  """
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  # Module.to moves parameters in place, so the optimizer's references stay valid.
  model = model.to(device)
  model.train()
  metric_tracker.clear()

  for X, yt in dataloader:
    X = X.to(device, non_blocking=True)
    yt = yt.to(device, non_blocking=True)

    optimizer.zero_grad()
    logits = model(X)
    loss = loss_func(logits, yt)
    metric_tracker.add(logits.argmax(-1), yt)
    if tblog:
      wandb.log({'loss': loss.item()})
    loss.backward()
    optimizer.step()

def train(model, trDataLoader, vlDataLoader, optimizer, num_epochs, tblog=None,
          save_path="pathtosavemodel/secondTry.pt"):
  """Train for ``num_epochs`` epochs, validating after each and checkpointing the best.

  After every epoch the training and validation acc / mAcc / mIoU are printed
  and, when ``tblog`` is truthy, logged to wandb. The model's ``state_dict`` is
  written to ``save_path`` whenever the validation mIoU is at least as good as
  the best seen so far.

  Args:
    model: 10-class classifier to train.
    trDataLoader: training batches of ``(X, yt)``.
    vlDataLoader: validation batches of ``(X, yt)``.
    optimizer: optimizer over ``model``'s parameters (no LR schedule is applied).
    num_epochs: number of epochs to run.
    tblog: used only as an on/off flag for wandb logging.
    save_path: checkpoint file; missing parent directories are created.
  """
  import os

  loss_func = nn.CrossEntropyLoss()
  metric_tracker = ClassificationMetrics(10)

  def current_scores():
    # float() accepts Python numbers and 0-dim tensors alike.
    return (float(metric_tracker.acc()),
            float(metric_tracker.mAcc()),
            float(metric_tracker.mIoU()))

  save_dir = os.path.dirname(save_path)
  if save_dir:
    # torch.save does not create missing directories.
    os.makedirs(save_dir, exist_ok=True)

  best_miou = float('-inf')
  for epoch in range(1, num_epochs + 1):
    print("-- EPOCH {}/{} -------------------------\n".format(epoch, num_epochs))

    train_one_epoch(model, loss_func, metric_tracker, trDataLoader, optimizer, epoch, tblog)
    tr_acc, tr_macc, tr_miou = current_scores()
    print("\tTRAIN | acc: {:.4f} | mAcc: {:.4f} | mIoU: {:.4f}".format(tr_acc, tr_macc, tr_miou))

    validate(model, metric_tracker, vlDataLoader)
    vl_acc, vl_macc, vl_miou = current_scores()
    print("\tEVAL  | acc: {:.4f} | mAcc: {:.4f} | mIoU: {:.4f}\n".format(vl_acc, vl_macc, vl_miou))

    if tblog:
      # One call per epoch keeps all of the epoch's metrics on the same wandb step.
      wandb.log({
          'Train/acc': tr_acc, 'Train/mAcc': tr_macc, 'Train/mIoU': tr_miou,
          'val/acc': vl_acc, 'val/mAcc': vl_macc, 'val/mIoU': vl_miou,
      })

    # Save the best-performing net across epochs (ties go to the later epoch).
    # Author's note: with Adam the learning rate is left fixed (no scheduler).
    if vl_miou >= best_miou:
      torch.save(model.state_dict(), save_path)
      best_miou = vl_miou



In [None]:
def load_data(root=None, val_fraction=0.2, batch_size=128, num_workers=2, seed=None):
  """Build the train / validation DataLoaders for the galaxy ImageFolder dataset.

  Every image is used in three views: center crop, resize + random rotation,
  and plain resize. The train/validation split is made over *images* before
  the views are expanded, so no image contributes views to both splits.
  Splitting the concatenated views would leak transformed copies of
  validation images into training and inflate validation scores.

  Args:
    root: ImageFolder root directory. Defaults to the notebook-level
      ``path_dataset_train`` variable, which must then be defined.
    val_fraction: fraction of images held out for validation.
    batch_size: batch size of both loaders.
    num_workers: DataLoader worker processes.
    seed: optional seed for a reproducible split; ``None`` uses torch's global RNG.

  Returns:
    ``(trDataLoader, vlDataLoader)``. The validation loader is not shuffled and
    keeps its last partial batch, so every held-out sample is evaluated.
  """
  T = torchvision.transforms
  if root is None:
    # NOTE(review): global not defined in the visible notebook; set it before calling.
    root = path_dataset_train

  # ImageNet mean/std, as expected by the ImageNet-pretrained backbone.
  normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  view_transforms = [
      T.Compose([T.CenterCrop(224), T.ToTensor(), normalize]),
      T.Compose([T.Resize(224), T.RandomRotation((0, 360)), T.ToTensor(), normalize]),
      T.Compose([T.Resize(224), T.ToTensor(), normalize]),
  ]
  views = [torchvision.datasets.ImageFolder(root, transform=t) for t in view_transforms]

  # Split image indices once, then apply the same split to every view.
  n_image = len(views[0])
  n_val = int(n_image * val_fraction)
  generator = torch.Generator().manual_seed(seed) if seed is not None else None
  order = torch.randperm(n_image, generator=generator).tolist()
  val_idx, tr_idx = order[:n_val], order[n_val:]

  tr_dataset = torch.utils.data.ConcatDataset([torch.utils.data.Subset(v, tr_idx) for v in views])
  vl_dataset = torch.utils.data.ConcatDataset([torch.utils.data.Subset(v, val_idx) for v in views])
  print('total dataset: {} images x {} views\ntraining set: {}\nvalidation set: {}'.format(
      n_image, len(views), len(tr_dataset), len(vl_dataset)))

  pin = torch.cuda.is_available()
  trDataLoader = torch.utils.data.DataLoader(tr_dataset, batch_size=batch_size, shuffle=True,
                                             num_workers=num_workers, pin_memory=pin, drop_last=True)
  vlDataLoader = torch.utils.data.DataLoader(vl_dataset, batch_size=batch_size, shuffle=False,
                                             num_workers=num_workers, pin_memory=pin, drop_last=False)
  return trDataLoader, vlDataLoader

In [None]:
#main
# Fine-tune an ImageNet-pretrained ResNet-18 on the 10 galaxy classes.

# Seed torch's global RNG: covers the new fc layer's init, the train/val split,
# shuffling and augmentation (GPU kernels may still be non-deterministic).
torch.manual_seed(0)

# NOTE(review): torchvision >= 0.13 deprecates `pretrained=True` in favour of
# `weights=models.ResNet18_Weights.DEFAULT`; kept as-is for older versions.
rete = models.resnet18(pretrained=True)
num_ftrs = rete.fc.in_features
rete.fc = nn.Linear(num_ftrs, 10)  # replace the 1000-way ImageNet head with a 10-class head

trData, vlData = load_data()

# Named `optimizer` (not `optim`) so the `torch.optim as optim` module alias from
# the imports cell is not shadowed when cells are re-run.
optimizer = torch.optim.Adam(rete.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

# "Resnet18" is passed as `tblog`, which train() only uses as an on/off flag for wandb.
train(rete, trData, vlData, optimizer, 20, "Resnet18")
