In [None]:
# Import the necessary packages
import os
import time

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split, Dataset, DataLoader

from model import DenseNet121

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
DEVICE

In [None]:
class ModelTrainer:
    """Train and validate a classifier, recording one loss value per epoch.

    The constructor moves the model to ``DEVICE`` and builds an Adam
    optimizer, a ``ReduceLROnPlateau`` scheduler (stepped with the validation
    loss) and a cross-entropy criterion.

    Args:
        model: ``nn.Module`` whose output is scored with ``CrossEntropyLoss``;
            the predicted class is the argmax over dim 1.
        max_epochs: total number of epochs, used only in the log lines.
        run_id: name of the sub-directory of ``experiments/`` that ``save``
            writes checkpoints to.

    Attributes:
        epochs: number of epochs completed so far.
        train_losses: mean training loss of every completed epoch.
        val_losses: mean validation loss of every completed epoch.
    """

    def __init__(self, model, max_epochs=1, run_id='exp'):
        self.model = model.to(DEVICE)
        self.max_epochs = max_epochs
        self.run_id = run_id
        # Must exist before the first train_val_epoch call, which increments it.
        self.epochs = 0
        self.train_losses = []
        self.val_losses = []

        # Define your optimizer and criterion here
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001, weight_decay=0.001)
        # The scheduler's `verbose` flag is deprecated in newer PyTorch
        # releases; learning-rate changes are reported by _step_scheduler.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=1)
        self.criterion = nn.CrossEntropyLoss()

    def _run_batches(self, loader, train):
        """Do one pass over ``loader``; update the weights only when ``train``.

        Returns:
            (mean loss per batch, fraction of correctly classified samples).

        Raises:
            ValueError: if ``loader`` yields no batches.
        """
        total_loss = 0.0
        num_correct = 0
        num_samples = 0
        num_batches = 0

        for X, Y in loader:
            # Inputs and targets must live on the same device as the model.
            X = X.to(DEVICE)
            Y = Y.to(DEVICE)

            if train:
                self.optimizer.zero_grad()

            out = self.model(X)
            loss = self.criterion(out, Y)

            if train:
                loss.backward()
                self.optimizer.step()

            # The prediction comes from the same forward pass as the loss.
            pred = out.detach().argmax(dim=1)
            num_correct += pred.eq(Y.view_as(pred)).sum().item()
            # Count samples actually seen so drop_last/samplers do not skew accuracy.
            num_samples += pred.numel()
            total_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError('data loader produced no batches')

        return total_loss / num_batches, num_correct / num_samples

    def _step_scheduler(self, val_loss):
        """Step the plateau scheduler and print any learning-rate change."""
        previous_lrs = [group['lr'] for group in self.optimizer.param_groups]
        self.scheduler.step(val_loss)
        for index, (old_lr, group) in enumerate(zip(previous_lrs, self.optimizer.param_groups)):
            if group['lr'] != old_lr:
                print('Epoch %d: reducing learning rate of group %d to %.4e.'
                      % (self.epochs, index, group['lr']))

    def train_val_epoch(self, train_loader, val_loader):
        """Train for one epoch, validate, step the scheduler and log both results.

        Appends the epoch's mean losses to ``train_losses`` / ``val_losses``
        and increments ``epochs``. Nothing is recorded if either pass raises.

        Args:
            train_loader: iterable of ``(X, Y)`` training batches.
            val_loader: iterable of ``(X, Y)`` validation batches.
        """
        self.model.train()
        epoch_loss, train_acc = self._run_batches(train_loader, train=True)

        val_loss, val_acc = self.validate(val_loader)

        # Count the epoch only after both passes succeeded.
        self.epochs += 1
        self._step_scheduler(val_loss)

        print('[TRAIN]  Epoch [%d/%d]   Loss: %.4f   Accuracy: %.4f'
                      % (self.epochs, self.max_epochs, epoch_loss, train_acc))
        self.train_losses.append(epoch_loss)

        print('[VAL]  Epoch [%d/%d]   Loss: %.4f   Accuracy: %.4f'
                      % (self.epochs, self.max_epochs, val_loss, val_acc))
        self.val_losses.append(val_loss)

    def validate(self, val_loader):
        """Evaluate the model on ``val_loader`` without updating any weights.

        Leaves the model in eval mode.

        Returns:
            (mean loss per batch, accuracy) as Python floats.

        Raises:
            ValueError: if ``val_loader`` yields no batches.
        """
        self.model.eval()
        # No gradients are needed for evaluation; skipping them saves memory.
        with torch.no_grad():
            return self._run_batches(val_loader, train=False)

    def save(self):
        """Write the model's state dict to experiments/<run_id>/model-<epochs>.pkl."""
        run_dir = os.path.join('experiments', self.run_id)
        # Create the directory on demand so save() works even if the caller
        # has not prepared it.
        os.makedirs(run_dir, exist_ok=True)
        model_path = os.path.join(run_dir, 'model-{}.pkl'.format(self.epochs))
        torch.save({'state_dict': self.model.state_dict()}, model_path)

### Setting hyperparameters

In [None]:
NUM_EPOCHS = 5    # number of train/validation epochs to run
BATCH_SIZE = 16   # samples per mini-batch for both data loaders

### Load all data

In [1]:
# TODO: replace the `...` placeholder with the actual Dataset instance before
# running this cell; it must support len() and integer indexing.
all_dataset = ...
# The split sizes computed later need the number of samples, not the dataset
# object itself.
all_data_len = len(all_dataset)

### Create folder to store results

In [None]:
# Each run gets its own folder named after the current Unix timestamp (seconds).
run_id = str(int(time.time()))
run_dir = os.path.join('experiments', run_id)
# makedirs creates the parent folder as needed, and exist_ok keeps the cell
# re-runnable within the same second instead of raising FileExistsError.
os.makedirs(run_dir, exist_ok=True)
print("Saving models, predictions, and generated words to ./experiments/%s" % run_id)

### Load model 

In [None]:
model = DenseNet121()

# Optional warm start: uncomment the lines below and fill in the checkpoint
# path to copy over every pretrained entry whose key also exists in the model.
# pretrained_dict = torch.load(...)
# model_dict = model.state_dict()
# # 1. filter out unnecessary keys
# pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# # 2. overwrite entries in the existing state dict
# model_dict.update(pretrained_dict)
# # 3. load the new state dict
# model.load_state_dict(model_dict)

# The trainer owns the optimizer, scheduler and loss for this run.
trainer = ModelTrainer(
    model=model,
    max_epochs=NUM_EPOCHS,
    run_id=run_id,
)

### Start training (a new random train/validation split is drawn every epoch)

In [None]:
# Lowest validation loss seen so far; a checkpoint is written whenever it improves.
best_loss = float('inf')

# random_split requires the lengths to sum to the dataset size.
# NOTE(review): the original lengths (len // 8 and len // 2) did not, so the
# call raised ValueError. Holding out 1/8 for validation and training on the
# rest is an assumed ratio -- confirm or adjust.
val_len = all_data_len // 8
train_len = all_data_len - val_len

for epoch in range(NUM_EPOCHS):
    # Draw a fresh random train/validation split for this epoch.
    # NOTE(review): because the split changes every epoch, validation samples
    # may have been trained on in earlier epochs, so the validation loss is
    # optimistic.
    train_dataset, val_dataset = random_split(all_dataset, [train_len, val_len])
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=4)

    start_time = time.time()
    trainer.train_val_epoch(train_loader, val_loader)

    # train_val_epoch appends this epoch's validation loss to trainer.val_losses.
    epoch_val_loss = trainer.val_losses[-1]
    if epoch_val_loss < best_loss:
        best_loss = epoch_val_loss
        print("Saving model, predictions and generated output for epoch " +
              str(epoch+1)+" with validation loss: " + str(best_loss))
        trainer.save()

    elapsed_time = time.time() - start_time
    print('Time elapsed: ', time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))
    

### Plot curve

In [None]:
# Draw one figure per loss history, each against the 1-based epoch number.
epoch_axis = range(1, trainer.epochs + 1)
for history, legend_label in (
    (trainer.train_losses, 'Training losses'),
    (trainer.val_losses, 'Validation losses'),
):
    plt.figure()
    plt.plot(epoch_axis, history, label=legend_label)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()