# In [2]:
import torch
from tqdm.auto import tqdm

# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
# The functions below capture this value as their default ``device`` argument.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Training and evaluation functions
def train_step(model, train_dataloader, loss_fn, optimizer, device=device):
    """Run one training epoch over ``train_dataloader`` and update ``model``.

    For every batch the inputs and targets are moved to ``device``, the model
    is run forward, ``loss_fn`` is evaluated, gradients are back-propagated,
    and ``optimizer`` takes one step.

    Args:
        model: Network to train; it is switched to training mode first.
        train_dataloader: Iterable of ``(inputs, targets)`` batches that also
            supports ``len()`` (used as the number of batches).
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a
            single-element loss tensor.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the batches are moved to; defaults to the module-level
            ``device`` captured when this function was defined.

    Returns:
        ``(train_loss, train_acc)``: loss and accuracy averaged over batches.
        Every batch counts equally, so a smaller final batch is weighted the
        same as a full one.

    Raises:
        ZeroDivisionError: If ``train_dataloader`` has length 0.
    """
    model.train()
    train_loss, train_acc = 0, 0
    for x, y in train_dataloader:
        x, y = x.to(device), y.to(device)
        # 1. Forward pass
        y_pred = model(x)
        # 2. Loss
        loss = loss_fn(y_pred, y)
        train_loss += loss.item()
        # 3. Clear gradients accumulated by the previous step
        optimizer.zero_grad()
        # 4. Back-propagate
        loss.backward()
        # 5. Update parameters
        optimizer.step()

        # Softmax is monotonic, so argmax over the raw outputs selects the
        # same class without an extra pass (and without float ties that
        # softmax can introduce); this matches how test_step predicts.
        # NOTE(review): assumes y_pred is (batch, num_classes) and y holds
        # integer class indices -- confirm against the model/dataset.
        y_pred_class = y_pred.argmax(dim=1)
        train_acc += (y_pred_class == y).sum().item() / len(y_pred)
    train_loss = train_loss / len(train_dataloader)
    train_acc = train_acc / len(train_dataloader)
    return train_loss, train_acc
def test_step(model, test_dataloader, loss_fn, device=device):
    """Evaluate ``model`` on ``test_dataloader`` without updating it.

    The model is put in evaluation mode and every batch is processed under
    ``torch.inference_mode()``, so no gradients are tracked.

    Args:
        model: Network to evaluate.
        test_dataloader: Iterable of ``(inputs, targets)`` batches that also
            supports ``len()`` (used as the number of batches).
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a
            single-element loss tensor.
        device: Device the batches are moved to; defaults to the module-level
            ``device`` captured when this function was defined.

    Returns:
        ``(test_loss, test_acc)`` averaged over batches, each batch weighted
        equally regardless of its size.

    Raises:
        ZeroDivisionError: If ``test_dataloader`` has length 0.
    """
    model.eval()
    total_loss = 0
    total_acc = 0
    with torch.inference_mode():
        for inputs, targets in test_dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            total_loss += loss_fn(outputs, targets).item()

            # Predicted class = index of the largest output along dim 1.
            predicted = outputs.argmax(dim=1)
            total_acc += (predicted == targets).sum().item() / len(predicted)
    num_batches = len(test_dataloader)
    return total_loss / num_batches, total_acc / num_batches
def train(model, train_dataloader, test_dataloader, loss_fn, optimizer,
          epochs, device=device):
    """Train ``model`` for ``epochs`` epochs, evaluating after each one.

    Each epoch calls ``train_step`` on ``train_dataloader``, then
    ``test_step`` on ``test_dataloader``, prints a one-line summary, and
    records the four metrics.

    Args:
        model: Network to train and evaluate.
        train_dataloader: Batches used for the training pass.
        test_dataloader: Batches used for the evaluation pass.
        loss_fn: Loss callable passed through to both steps.
        optimizer: Optimizer passed through to ``train_step``.
        epochs: Number of epochs to run (``range(epochs)``).
        device: Device passed through to both steps; defaults to the
            module-level ``device`` captured when this function was defined.

    Returns:
        dict mapping ``"train_loss"``, ``"train_acc"``, ``"test_loss"`` and
        ``"test_acc"`` to lists holding one value per epoch, in epoch order.
    """
    results = {"train_loss": [],
               "train_acc": [],
               "test_loss": [],
               "test_acc": []}
    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step(model, train_dataloader, loss_fn,
                                           optimizer, device)
        test_loss, test_acc = test_step(model, test_dataloader, loss_fn,
                                        device)
        # tqdm.write prints to stdout like print(), but clears and redraws the
        # active progress bar so the summary line does not corrupt it.
        tqdm.write(
            f"Epoch: {epoch+1} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_acc: {train_acc:.4f} | "
            f"test_loss: {test_loss:.4f} | "
            f"test_acc: {test_acc:.4f}"
        )
        # Update results dictionary
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["test_loss"].append(test_loss)
        results["test_acc"].append(test_acc)
    return results