# Lighter training paper

### Setting DataLoader and Image Transformations
To make sure that everything is installed and working properly, we first download the dataset and simulate one inference cycle on a single batch.

In [None]:
import torch
from torch import nn
from torchvision import datasets
import torchvision.transforms.v2 as transforms
from torch.utils.data import DataLoader
from torch.utils.data import RandomSampler

from Lighter import Lighter
from utils.transformations import ToGraph, NoisyImage

In [None]:
# loading MNIST
# Per-sample preprocessing pipeline, shared by the train and the test split
# (so the random augmentations below are applied to test images as well).
data_transforms = transforms.Compose(
  [
    transforms.ToTensor(),                        # PIL image -> tensor
    transforms.Pad((0, 0, 2, 2)),                 # (left, top, right, bottom): pad right/bottom by 2 px
    transforms.RandomRotation(degrees=(0, 180)),  # random angle drawn from [0, 180] degrees
    NoisyImage(),                                 # project transform; presumably adds noise -- see utils.transformations
    transforms.RandomInvert(p=0.5),               # invert colors for half of the samples
    ToGraph(),                                    # project transform; presumably yields the (features, adjacency) pair unpacked from batches later
  ]
)

# Arguments common to both MNIST splits.
_mnist_common = {
  "root": "./data",
  "download": True,
  "transform": data_transforms,
}

train_dataset = datasets.MNIST(train=True, **_mnist_common)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

test_dataset = datasets.MNIST(train=False, **_mnist_common)
test_loader = DataLoader(test_dataset, batch_size=1)

In [18]:
# Select a class-balanced subset: the first 500 training samples and the first
# 100 test samples of every label (5000 / 1000 items for the 10 MNIST digits).
train_items_per_target = 500
test_items_per_target = 100

# Each list starts with an empty int32 tensor so that a single torch.cat at the
# end gives the same result as growing the tensor label by label.
train_chunks = [torch.Tensor().int()]
test_chunks = [torch.Tensor().int()]

for target in train_dataset.targets.unique().tolist():
  train_hits = (train_dataset.targets == target).nonzero().reshape(-1)
  train_chunks.append(train_hits[:train_items_per_target])
  test_hits = (test_dataset.targets == target).nonzero().reshape(-1)
  test_chunks.append(test_hits[:test_items_per_target])

train_indeces = torch.cat(train_chunks, dim=0)
test_indeces = torch.cat(test_chunks, dim=0)

# The indices were gathered label by label; shuffle them so labels are mixed.
train_order = torch.randperm(train_indeces.size(0))
train_indeces = train_indeces[train_order]
test_order = torch.randperm(test_indeces.size(0))
test_indeces = test_indeces[test_order]


In [None]:
# Smoke test: run one training batch through a freshly built model and print
# the shape of its output; gradients are not tracked for the forward pass.
model = Lighter()

for (X, A), L in train_loader:
  with torch.no_grad():
    output = model(X, A)
  print(output.shape)
  break  # a single batch is enough to check the pipeline
  

### Train Loop
We will use WandB to log the data obtained from the training process.

In [9]:
import wandb
import itertools
import inspect

In [None]:
# Authenticate this session with Weights & Biases before any run is created.
# relogin=True asks wandb to prompt for credentials again even when a key is
# already stored (per the wandb.login documentation).
wandb.login(relogin=True)

In [34]:
# train loop settings

# optimisation
batch_size = 64
learning_rate = 1e-3

# Weights & Biases
project_name = "Lighter 6.0"  # used as the run name when a new run is started
run_id = "mltxwb9p"           # id of the run to resume when restore_run is set
log_enabled = False           # master switch for all wandb logging/checkpoint uploads
# Resuming only makes sense while logging is on; flip the literal to enable it.
restore_run = log_enabled and False

In [None]:
# Training cell: builds the loaders, model and optimizer, optionally resumes a
# W&B run, then trains epoch after epoch until the kernel is interrupted.

# device
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

# dataloaders
# SubsetRandomSampler iterates (in random order) over the *values* stored in the
# index tensors. RandomSampler(indices) only uses len(indices) and yields the
# positions 0..len-1, which would silently train/evaluate on the first samples
# of each dataset instead of the class-balanced selection.
train_loader = DataLoader(
  train_dataset,
  batch_size=batch_size,
  sampler=torch.utils.data.SubsetRandomSampler(train_indeces.tolist()),
)
test_loader = DataLoader(
  test_dataset,
  batch_size=batch_size,
  sampler=torch.utils.data.SubsetRandomSampler(test_indeces.tolist()),
)

model = Lighter().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
cross_entropy = nn.CrossEntropyLoss()

starting_epoch = 1

# restoring last run status
if restore_run:
  wandb.init(
    entity="lorenzocusin02",
    project="ML Project",
    id=run_id,
    resume="must"
  )
  api = wandb.Api()
  run = api.run(f"lorenzocusin02/ML Project/{run_id}")
  local_path = run.file("train_status.info").download(replace=True)
  # NOTE: torch.load unpickles the file -- only resume from runs you trust.
  checkpoint = torch.load("./train_status.info", map_location=torch.device('cpu'))
  model.load_state_dict(checkpoint['model_state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  # the stored value is the epoch to run next (see _save_checkpoint)
  starting_epoch = checkpoint['epoch']
  print("> Model restored")
elif log_enabled:
  wandb.init(
    project = "ML Project",
    name = project_name,
    notes = f"""
        LEARNING RATE: {learning_rate}
        class Lighter(nn.Module):
          {inspect.getsource(Lighter.__init__)}
          {inspect.getsource(Lighter.forward)}
    """
  )

# Apply the configured learning rate even when the optimizer state was
# restored from a checkpoint that used a different one.
for param_group in optimizer.param_groups:
  param_group['lr'] = learning_rate

# per-epoch history
train_loss = []
train_acc = []
test_loss = []
test_acc = []


def _save_checkpoint(next_epoch):
  """Save model/optimizer state to ./train_status.info and upload it to W&B.

  next_epoch is stored under 'epoch' and is the epoch a resumed run starts at.
  """
  print("> Saving training status")
  torch.save({
      'epoch': next_epoch,
      'model_state_dict': model.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
    },
    "./train_status.info"
  )
  wandb.save("./train_status.info")


def _run_epoch(loader, training):
  """Run one pass over loader and return (mean loss, accuracy in percent).

  With training=True the model is put in train mode and the optimizer is
  stepped on every batch; otherwise the model is in eval mode and gradients
  are disabled. Both metrics are averaged per sample, so a smaller final
  batch does not get the same weight as a full one.

  Raises:
    ValueError: if the loader yields no samples.
  """
  model.train(training)
  loss_sum = 0.0
  correct = 0
  seen = 0
  with torch.set_grad_enabled(training):
    for (features, adjency_matrix), labels in loader:
      features = features.to(device)
      adjency_matrix = adjency_matrix.float().to(device)
      labels = labels.to(device)

      output = model(features, adjency_matrix)
      loss = cross_entropy(output, labels)

      if training:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

      batch_items = labels.shape[0]
      # CrossEntropyLoss averages over the batch; undo that to sum per sample
      loss_sum += loss.item() * batch_items
      correct += (torch.max(output, 1).indices == labels).sum().item()
      seen += batch_items

  if seen == 0:
    raise ValueError("the data loader did not yield any samples")
  return loss_sum / seen, correct / seen * 100


print(f"> Device: {device}")
print(f"> Learning rate: {learning_rate}")

# The loop has no stopping criterion: it runs until the cell is interrupted.
# try/finally makes sure the W&B run is closed when that happens.
try:
  for epoch in itertools.count(start=starting_epoch):
    print(f"> Epoch {epoch}")

    # training
    epoch_loss, epoch_acc = _run_epoch(train_loader, training=True)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)

    print(f"\tTrain Loss: {train_loss[-1]}")
    print(f"\tTrain Accuracy: {train_acc[-1]}")

    # evaluation
    epoch_loss, epoch_acc = _run_epoch(test_loader, training=False)
    test_loss.append(epoch_loss)
    test_acc.append(epoch_acc)

    print(f"\tTest Loss: {test_loss[-1]}")
    print(f"\tTest Accuracy: {test_acc[-1]}")

    print()

    if log_enabled:
      wandb.log({
        "train_loss" : train_loss[-1],
        "train_acc" : train_acc[-1],
        "validation_loss" : test_loss[-1],
        "validation_acc" : test_acc[-1],
      })

    # Checkpoint once, after the epoch is complete, and record the epoch that
    # comes next so a resumed run does not repeat the one just finished.
    if (epoch % 10 == 0) and log_enabled:
      _save_checkpoint(epoch + 1)
finally:
  if log_enabled:
    wandb.finish()

In [828]:
# Manual checkpoint: write the current epoch counter together with the model
# and optimizer state to the same local file the training loop uses.
torch.save(
  obj={
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
  },
  f="./train_status.info",
)