In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs,
                save_path='/home/shah/Desktop/FB-Marketplace-Recommendation-Ranking-System/data/model_state_dict.pt'):
    """Train ``model`` and return it loaded with its best validation weights.

    Every epoch runs a ``'train'`` phase followed by a ``'val'`` phase. The
    weights that reach the highest validation accuracy are kept in memory,
    restored into ``model`` once all epochs are done, and optionally written
    to ``save_path``.

    The function reads these module-level names, which must be defined before
    it is called:

    * ``dataloaders`` -- mapping with ``'train'`` and ``'val'`` entries, each
      an iterable yielding ``(inputs, labels)`` batches.
    * ``device`` -- target passed to ``Tensor.to`` for every batch.
    * ``writer`` -- object exposing ``add_scalar(tag, value, step)``
      (presumably a TensorBoard ``SummaryWriter`` -- confirm with the caller).

    Args:
        model: Network to optimise; must support ``train()``, ``eval()``,
            ``state_dict()`` and ``load_state_dict()``.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
            The epoch loss assumes it returns the batch mean -- TODO confirm.
        optimizer: Optimizer that updates ``model``'s parameters.
        scheduler: Learning-rate scheduler; stepped once per training epoch.
        num_epochs: Number of epochs to run.
        save_path: Where the best ``state_dict`` is written with
            ``torch.save``. Pass ``None`` to skip writing to disk.

    Returns:
        ``model`` itself, with the best validation weights loaded.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                print("training in progress")
                model.train()  # enable training-mode layer behaviour
            else:
                print("validating in progress")
                model.eval()   # switch to evaluation-mode layer behaviour

            running_loss = 0.0
            running_corrects = 0
            samples_seen = 0   # real sample count; batch size is not assumed
            counter = 0        # batches processed in this phase
            hund_loss = 0.0    # loss accumulated since the last 100-batch report

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Gradients are only tracked while training.
                with torch.set_grad_enabled(is_train):
                    counter += 1
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    # Read the scalar once; every .item() forces a device sync.
                    batch_loss = loss.item()
                    hund_loss += batch_loss
                    if counter % 100 == 0:
                        print(f'{phase} Round {epoch} of {num_epochs} loss: {(hund_loss / 100)} Running corrects: {running_corrects} Batch No: {counter}')
                        # NOTE(review): this tag is also written during the
                        # 'val' phase, and wall-clock time is used as the
                        # step; both kept as-is so existing dashboards are
                        # unaffected -- confirm this is intended.
                        writer.add_scalar('Training Loss', (hund_loss / 100), time.time())
                        hund_loss = 0.0

                    # backward + optimize only in the training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Accumulate statistics as plain Python numbers.
                batch_size = inputs.size(0)
                running_loss += batch_loss * batch_size
                running_corrects += torch.sum(preds == labels).item()
                samples_seen += batch_size

            if is_train:
                scheduler.step()

            # Normalise by the samples actually processed, so the metrics are
            # right for any batch size and for a short final batch. The guard
            # avoids a ZeroDivisionError if a loader yields nothing.
            denominator = max(samples_seen, 1)
            epoch_loss = running_loss / denominator
            epoch_acc = running_corrects / denominator
            # NOTE(review): train and val accuracy share this tag, as before.
            writer.add_scalar('Epoch Accuracy', epoch_acc, time.time())

            print(f'Phase: {phase} -- Loss: {epoch_loss:.4f} -- Acc: {(epoch_acc * 100):.1f}%')

            # Snapshot the weights whenever validation accuracy improves.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training model took {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Overall Val Acc is: {(best_acc * 100):.1f}%')

    # Restore the best weights and persist them if a destination was given.
    model.load_state_dict(best_model_wts)
    if save_path is not None:
        torch.save(model.state_dict(), save_path)
    return model