In [None]:
import torch

class Model_Training:
    """Run per-epoch training and validation passes for a PyTorch model.

    The object only stores the collaborators it is given; each call to
    ``train_loop`` or ``validation_loop`` performs one full pass over the
    corresponding data loader.
    """

    def __init__(self, training_dataloader, validation_dataloader, model, loss_fn, optimizer, device):
        """Store the collaborators used by the loops.

        Args:
            training_dataloader: Iterable of ``(X, y)`` batches that also
                exposes ``.dataset`` and ``len()``.
            validation_dataloader: Same contract as ``training_dataloader``.
            model: Callable applied to each ``X`` batch; it must provide
                ``train()`` and ``eval()`` (presumably a ``torch.nn.Module``).
            loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
            optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
            device: Device that every batch is moved to before the forward
                pass.
        """
        self.training_dataloader = training_dataloader
        self.validation_dataloader = validation_dataloader
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.device = device

    def train_loop(self):
        """Run one optimisation pass over the training data.

        Prints the current batch loss every 100 batches and the average
        loss at the end.

        Returns:
            The sum of the per-batch loss values divided by the number of
            batches, i.e. the same averaging ``validation_loop`` uses.
            NOTE(review): this is a per-sample mean only if ``loss_fn``
            returns a batch mean -- confirm against the configured loss.
        """
        size = len(self.training_dataloader.dataset)
        num_batches = len(self.training_dataloader)
        total_loss = 0.0
        torch.cuda.empty_cache()

        # A preceding validation pass leaves the model in eval mode, so
        # training mode has to be re-enabled explicitly on every call.
        self.model.train()

        for batch, (X, y) in enumerate(self.training_dataloader):
            # Compute prediction and loss
            X = X.to(self.device)
            y = y.to(self.device)
            pred = self.model(X)
            loss = self.loss_fn(pred, y)

            # Backpropagation
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Read the scalar once and keep `loss` bound to the tensor.
            batch_loss = loss.item()
            total_loss += batch_loss
            if batch % 100 == 0:
                current = batch * len(X)
                print(f"loss: {batch_loss:>7f}  [{current:>5d}/{size:>5d}]")

        # Average over batches (not samples) so the value is on the same
        # scale as the one reported by validation_loop.
        avg_loss = total_loss / num_batches
        print(f"loss: {avg_loss:>7f}")
        return avg_loss

    def validation_loop(self):
        """Evaluate the model on the validation data without gradients.

        Prints the accuracy and the average loss.

        Returns:
            The sum of the per-batch loss values divided by the number of
            batches.
        """
        size = len(self.validation_dataloader.dataset)
        num_batches = len(self.validation_dataloader)
        test_loss, correct = 0, 0

        # Switch layers such as dropout/batch-norm to inference behaviour
        # so the reported metrics are not affected by training-time noise.
        self.model.eval()

        with torch.no_grad():
            for X, y in self.validation_dataloader:
                X = X.to(self.device)
                y = y.to(self.device)

                pred = self.model(X)
                test_loss += self.loss_fn(pred, y).item()
                # The predicted class is the arg-max over dimension 1, compared
                # element-wise with y (assumes y holds class indices -- TODO confirm).
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()

        test_loss /= num_batches
        correct /= size
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
        return test_loss
    