In [None]:
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.distributions.beta import Beta
from PIL import Image
from torch.autograd import Variable
from torchvision.models import resnet50, ResNet50_Weights

# Pick the compute device once: the first CUDA GPU when one is visible,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
class Net(nn.Module):
    """ResNet-50 backbone with a freshly initialised linear classification head.

    Every child of torchvision's ``resnet50`` except the last one (the original
    fully-connected head) is reused as the feature extractor. A new
    ``nn.Linear`` maps the flattened features to ``num_classes`` logits.

    Args:
        num_classes: Number of output logits. Defaults to 10, the value this
            class previously hard-coded.
        weights: Passed straight to ``resnet50(weights=...)``. This accepts a
            ``ResNet50_Weights`` member, its name as a string, or ``None`` for
            random initialisation. The default ``"DEFAULT"`` resolves to
            ``ResNet50_Weights.DEFAULT``, which is what was previously hard-coded.
    """

    def __init__(self, num_classes=10, weights="DEFAULT"):
        super().__init__()
        resnet = resnet50(weights=weights)
        # Read the head's input width from the backbone instead of hard-coding 2048.
        in_features = resnet.fc.in_features
        # Keep every child module except the last one (the original fc head).
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return logits of shape (batch_size, num_classes).

        NOTE(review): assumes ``x`` is an image batch shaped (N, 3, H, W) as
        ResNet-50 expects — confirm against the caller.
        """
        x = self.features(x)
        # Collapse every dimension after the batch dimension into one feature axis.
        x = torch.flatten(x, 1)
        return self.classifier(x)

In [None]:
class Mixup(nn.Module):
    '''
    Mixup augmentation: blend each sample with a randomly chosen partner from the same batch.

    Calling the module (``mixup(x, y)``, dispatched by ``nn.Module`` to ``forward``)
    draws one mixing coefficient ``lam`` and one random permutation of the batch. It
    returns ``lam * x + (1 - lam) * x[perm]``, the original labels, the permuted labels
    and ``lam``. The loss is then typically ``lam * loss(y_a) + (1 - lam) * loss(y_b)``.

    Inputs:

        x (torch.Tensor) - Input tensor of shape (batch_size, *), where * is any number
            of additional dimensions (including none).
        y (torch.Tensor) - Label tensor of shape (batch_size,).
        alpha (float, optional) - Beta(alpha, alpha) parameter used by sampling_method 1.
            Values <= 0 disable mixing (lam = 1). Default is 1.0.
        sampling_method (int, optional) - Method used to sample the mixup coefficient:
            1 = Beta(alpha, alpha), 2 = Uniform(range_min, range_max). Any other value
            raises ValueError when the module is called. Default is 1.
        range_min (float, optional): Lower bound of the uniform distribution used by
            sampling_method 2. Default is 0.2.
        range_max (float, optional): Upper bound of the uniform distribution used by
            sampling_method 2. Default is 0.8.

    Outputs:

        mixed_x (torch.Tensor) - Tensor of shape (batch_size, *) containing the mixed input data.
        y_a (torch.Tensor) - Tensor of shape (batch_size,) containing the original labels.
        y_b (torch.Tensor) - Tensor of shape (batch_size,) containing the labels of the mixing partners.
        lam (float) - Mixup coefficient used for mixing the input data and labels.
    '''
    def __init__(self, alpha=1.0, sampling_method=1, range_min=0.2, range_max=0.8):
        # nn.Module must set up its internal registries before use; without this,
        # repr(), .to(), .train() and hooks fail with AttributeError.
        super().__init__()
        self.alpha = alpha
        self.sampling_method = sampling_method
        self.range_min = range_min
        self.range_max = range_max

    def extra_repr(self):
        """Show the sampling configuration in ``repr(module)``."""
        return (f"alpha={self.alpha}, sampling_method={self.sampling_method}, "
                f"range_min={self.range_min}, range_max={self.range_max}")

    def _sample_lambda(self):
        """Draw the mixing coefficient according to ``sampling_method``."""
        if self.sampling_method == 1:
            # np.random.beta rejects non-positive parameters; as in the reference
            # mixup implementation, alpha <= 0 means "no mixing".
            if self.alpha > 0:
                return float(np.random.beta(self.alpha, self.alpha))
            return 1.0
        if self.sampling_method == 2:
            return float(np.random.uniform(self.range_min, self.range_max))
        raise ValueError(f"Invalid sampling method: {self.sampling_method}")

    def forward(self, x, y):
        """Mix ``x`` with a shuffled copy of itself; see the class docstring for outputs."""
        batch_size = x.size(0)
        lam = self._sample_lambda()

        # Permute on the CPU generator (as before) so seeded runs stay reproducible,
        # then move to the data's device.
        indices = torch.randperm(batch_size).to(x.device)
        # Index only the batch dimension so 1-D inputs work as well as N-D ones.
        mixed_x = lam * x + (1 - lam) * x[indices]
        # Labels may live on a different device than the inputs.
        y_a, y_b = y, y[indices.to(y.device)]
        return mixed_x, y_a, y_b, lam