In [None]:
import random
import matplotlib.pyplot as plt

def print_random_image(dataset, size=5):
    '''
    Prints the label of a randomly chosen dataset item and displays its image with matplotlib.

    -----
    Args:
        dataset (Dataset): indexable PyTorch dataset with a length; each item is expected to be a
            pair (image tensor shaped (C, H, W), label). The label is printed as-is, so any printable
            type (str, int, ...) works.
        size (int): figure size for matplotlib figure, affects the size of the displayed image, defaults to 5

    Raises:
        ValueError: if the dataset is empty.
    '''
    n_items = len(dataset)
    if n_items == 0:
        raise ValueError("cannot display a random image from an empty dataset")

    # randrange(n) draws from the same range as randint(0, n - 1)
    sample = dataset[random.randrange(n_items)]
    image, label = sample[0], sample[1]

    print('image label: ', label)
    print()

    plt.figure(figsize=(size, size))
    # imshow wants channels-last data it can turn into a NumPy array; detach()/cpu()
    # make this also work for tensors that require grad or live on a GPU
    plt.imshow(image.detach().cpu().permute(1, 2, 0))

In [None]:
'''
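A transform for use with torchvision's transforms.Compose(). Takes an image tensor whose last two dimensions are (H, W) (e.g. (C, H, W)) and pads it with zeros (black, if applied before normalization) so that height equals width. The padding is split as evenly as possible between both sides.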

Usage example:
    transform_seg = transforms.Compose([
        SquarePadTensor(),
        transforms.Resize(224),
        transforms.Normalize([0.485, 0.456, 0.406], 
                             [0.229, 0.224, 0.225])
    ])
'''

class SquarePadTensor:
    '''
    Pads an image tensor with a constant value so that its height equals its width.

    The padding is split as evenly as possible between opposite sides; when the
    difference is odd, the extra pixel goes to the right / bottom. Any tensor whose
    last two dimensions are (H, W) is accepted, e.g. (C, H, W) or batched (N, C, H, W).

    -----
    Args:
        fill (int or float): value used for the padded pixels, defaults to 0
    '''

    def __init__(self, fill=0):
        self.fill = fill

    def __call__(self, image):
        h, w = image.shape[-2:]
        side = max(h, w)
        left = (side - w) // 2
        right = side - w - left
        top = (side - h) // 2
        bottom = side - h - top

        # torchvision's functional.pad takes a 4-sequence in (left, top, right, bottom) order
        padding = (left, top, right, bottom)
        return transforms.functional.pad(image, padding, fill=self.fill, padding_mode='constant')

    def __repr__(self):
        return f"{self.__class__.__name__}(fill={self.fill})"

In [None]:
# Standard ImageNet per-channel mean / std, the normalization expected by most
# models pretrained on ImageNet. Normalize operates on tensors, so images must
# already be tensors when this transform is applied.
transform = transforms.Compose(
    [
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
# basic training pipeline for PyTorch

def _train_one_epoch(model, criterion, optimizer, loader, device):
    '''Runs one optimisation pass over `loader`; returns the mean per-batch loss as a float (0.0 if empty).'''
    model.train()
    total_loss = 0.0
    n_batches = 0

    for data, label in tqdm(loader, desc="train", leave=False):
        data = data.to(device)
        label = label.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        # .item() yields a plain float; accumulating the loss tensor itself would keep
        # every batch's autograd graph alive until the end of the epoch
        total_loss += loss.item()
        n_batches += 1

    return total_loss / n_batches if n_batches else 0.0


def _evaluate(model, criterion, loader, device):
    '''Computes the mean per-batch loss on `loader` with gradients disabled; returns a float (0.0 if empty).'''
    model.eval()
    total_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for data, label in tqdm(loader, desc="val", leave=False):
            output = model(data.to(device))
            total_loss += criterion(output, label.to(device)).item()
            n_batches += 1

    return total_loss / n_batches if n_batches else 0.0


def train_model(model, criterion, optimizer, train_loader, val_loader, num_epochs=10,
                device=None, checkpoint_dir="model_checkpoints",
                checkpoint_prefix="fashion_model_checkpoint"):
    '''
    Trains `model` for `num_epochs` epochs. After each epoch it saves a checkpoint and prints the train and validation losses.

    -----
    Args:
        model (nn.Module): model to train; its parameters are updated in place
        criterion (callable): loss function, called as criterion(output, label)
        optimizer (Optimizer): optimizer over the model's parameters
        train_loader (iterable): yields (data, label) batches used for training
        val_loader (iterable): yields (data, label) batches used for validation
        num_epochs (int): number of passes over train_loader, defaults to 10
        device (torch.device or str, optional): where batches are moved; defaults to the
            device of the model's first parameter (CPU if the model has no parameters)
        checkpoint_dir (str or Path, optional): directory for per-epoch state_dict checkpoints,
            created if missing; pass None to disable checkpointing. Defaults to "model_checkpoints"
        checkpoint_prefix (str): checkpoint file name prefix; files are named
            "<prefix>_epoch_<n>.pt". Defaults to "fashion_model_checkpoint"

    Returns:
        nn.Module: the same model object, left in eval mode
    '''
    if device is None:
        # infer the device from the model instead of relying on a notebook-global `device`
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    if checkpoint_dir is not None:
        checkpoint_dir = Path(checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

    for epoch in tqdm(range(num_epochs), desc="epochs"):
        epoch_loss = _train_one_epoch(model, criterion, optimizer, train_loader, device)
        print(f"Epoch : {epoch + 1}, train loss : {epoch_loss}")

        if checkpoint_dir is not None:
            torch.save(model.state_dict(), checkpoint_dir / f"{checkpoint_prefix}_epoch_{epoch + 1}.pt")

        epoch_val_loss = _evaluate(model, criterion, val_loader, device)
        print(f"Epoch : {epoch + 1}, val_loss : {epoch_val_loss}")

    return model