<a href="https://colab.research.google.com/github/Bish-Soli/MQP-LTR/blob/main/MQP_Losses.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset,random_split
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import random_split
import torch.nn as nn
import torch.nn.functional as F
from __future__ import print_function
from sklearn.metrics import f1_score
import torch.optim as optim
from sklearn.metrics import confusion_matrix
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import os
from PIL import Image
import tarfile
import math

In [2]:
!pip install tensorboard_logger



In [3]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

class ImbalanceCIFAR(Dataset):
    """Long-tailed (class-imbalanced) subset of torchvision's CIFAR-10 / CIFAR-100.

    Class ``i`` keeps ``int(max_count * imbalance_ratio ** (i / (num_classes - 1)))``
    randomly chosen samples, where ``max_count`` is the size of the largest class
    in the wrapped dataset.

    Attributes:
        original_dataset: the wrapped torchvision CIFAR dataset.
        num_classes: 10 or 100.
        indices: positions in ``original_dataset`` that were kept, in shuffled order.
        targets: class labels aligned with ``indices`` (``targets[i]`` is the
            label of ``self[i]``).
    """

    def __init__(self, cifar_version=10, root='./data', train=True, transform=None, imbalance_ratio=0.01):
        """Download/load the requested CIFAR split and sub-sample it.

        Raises:
            ValueError: if ``cifar_version`` is neither 10 nor 100.
        """
        super(ImbalanceCIFAR, self).__init__()
        if cifar_version == 10:
            self.original_dataset = torchvision.datasets.CIFAR10(root=root, train=train, download=True, transform=transform)
        elif cifar_version == 100:
            self.original_dataset = torchvision.datasets.CIFAR100(root=root, train=train, download=True, transform=transform)
        else:
            raise ValueError("CIFAR version must be 10 or 100")

        self.num_classes = 10 if cifar_version == 10 else 100
        self._create_long_tailed(imbalance_ratio)

    def _create_long_tailed(self, imbalance_ratio):
        """Select the long-tailed subset and populate ``indices`` / ``targets``."""
        # Read the labels directly instead of iterating the dataset, which
        # would decode and transform every image just to count classes.
        all_targets = np.asarray(self.original_dataset.targets)
        class_counts = np.bincount(all_targets, minlength=self.num_classes)
        max_count = max(class_counts)

        # Exponentially decreasing number of samples from class 0 to the last class.
        num_samples_lt = [int(max_count * (imbalance_ratio ** (i / (self.num_classes - 1.0)))) for i in range(self.num_classes)]

        selected = []
        labels = []
        for i in range(self.num_classes):
            class_indices = np.where(all_targets == i)[0]
            np.random.shuffle(class_indices)
            chosen = class_indices[:num_samples_lt[i]]
            selected.extend(chosen)
            labels.extend([i] * len(chosen))

        # Shuffle indices and labels with the SAME permutation so that
        # targets[i] stays the label of the i-th sample.
        order = np.random.permutation(len(selected))
        self.indices = [int(selected[j]) for j in order]
        self.targets = [labels[j] for j in order]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the wrapped dataset's item for the ``idx``-th kept sample."""
        real_idx = self.indices[idx]
        return self.original_dataset[real_idx]

    def get_class_distribution(self):
        """Return ``{class_index: number_of_kept_samples}`` for every class."""
        # minlength keeps trailing classes with zero samples addressable.
        class_counts = np.bincount(np.asarray(self.targets, dtype=np.int64), minlength=self.num_classes)
        return {i: class_counts[i] for i in range(self.num_classes)}

    def get_class_names(self):
        return self.original_dataset.classes

    def plot_class_distribution(self):
        """Plot samples-per-class, sorted from the largest to the smallest class."""
        distribution = self.get_class_distribution()
        # Sort classes by the number of samples per class
        sorted_classes = sorted(distribution.items(), key=lambda item: item[1], reverse=True)

        # Separate the class indices and their corresponding counts
        sorted_indices, sorted_counts = zip(*sorted_classes)

        # Classes at or below the 30th percentile of counts are drawn as minority classes.
        threshold = np.percentile(sorted_counts, 30)

        # Create a line plot for class distribution
        plt.plot(sorted_indices, sorted_counts, label='Class Distribution')

        # Fill the area under the curve
        plt.fill_between(sorted_indices, sorted_counts, where=(np.array(sorted_counts) <= threshold), color='red', alpha=0.5, label='Minority classes')
        plt.fill_between(sorted_indices, sorted_counts, where=(np.array(sorted_counts) > threshold), color='green', alpha=0.5, label='Majority classes')

        # Add labels and title
        plt.xlabel('Sorted class indices (Large → Small)')
        plt.ylabel('Training samples per class')
        plt.title('Class Distribution in Dataset')
        plt.legend()

        # Show the plot
        plt.show()

    def get_samples_from_each_class(self, num_samples=1):
        """Return ``{class_index: [items]}`` with up to ``num_samples`` random items per class.

        A class with fewer than ``num_samples`` kept samples contributes all of
        them instead of raising ``IndexError``.
        """
        kept_indices = np.asarray(self.indices, dtype=np.int64)
        kept_targets = np.asarray(self.targets, dtype=np.int64)
        samples = {}
        for i in range(self.num_classes):
            class_indices = kept_indices[kept_targets == i]  # boolean indexing copies
            np.random.shuffle(class_indices)
            samples[i] = [self.original_dataset[int(j)] for j in class_indices[:num_samples]]
        return samples

    @staticmethod
    def imshow(img):
        """Display a (C, H, W) tensor image that was normalised with mean=std=0.5."""
        img = img.numpy().transpose((1, 2, 0))  # Convert from tensor image
        mean = np.array([0.5, 0.5, 0.5])
        std = np.array([0.5, 0.5, 0.5])
        img = std * img + mean  # Unnormalize
        img = np.clip(img, 0, 1)  # Clip to [0, 1]
        plt.imshow(img)
        plt.show()

    def show_augmented_images(self, image_index, augmentations, num_samples=5):
        """Show ``num_samples`` augmented versions of one image side by side.

        ``image_index`` addresses the wrapped (un-subsampled) dataset.
        """
        original_image, _ = self.original_dataset[image_index]
        images = [augmentations(original_image) for _ in range(num_samples)]
        grid_image = torchvision.utils.make_grid(images, nrow=num_samples)
        self.imshow(grid_image)

    def compute_imbalance_ratio(self):
        """Return (largest class size) / (smallest class size) of the kept subset."""
        class_distribution = self.get_class_distribution()
        max_count = max(class_distribution.values())
        min_count = min(class_distribution.values())
        return max_count / min_count

    def get_random_batch(self, batch_size=32):
        """Return ``batch_size`` distinct random items as a list of dataset items."""
        indices = np.random.choice(self.indices, batch_size, replace=False)
        return [self.original_dataset[idx] for idx in indices]

    # Deliberately left without ``self``/``@staticmethod``: this keeps both
    # existing call styles working, ``ImbalanceCIFAR.extract_features_and_labels(ds)``
    # and ``ds.extract_features_and_labels()`` (where the instance is ``dataset``).
    def extract_features_and_labels(dataset):
        """Flatten every image of ``dataset`` and return ``(features, labels)`` arrays."""
        features = []
        labels = []

        for img, label in dataset:
            # Flatten the image and convert to numpy array
            flattened_img = torch.flatten(img).numpy()
            features.append(flattened_img)
            labels.append(label)

        return np.array(features), np.array(labels)

    @staticmethod
    def visualize_with_tsne(features, labels, class_names):
        """Scatter-plot a 2-D t-SNE embedding of ``features``, one colour per class."""
        tsne = TSNE(n_components=2, random_state=123)
        tsne_results = tsne.fit_transform(features)

        plt.figure(figsize=(10, 8))
        for i, class_name in enumerate(class_names):
            indices = labels == i
            plt.scatter(tsne_results[indices, 0], tsne_results[indices, 1], label=class_name)
        plt.legend()
        plt.title('t-SNE visualization of the dataset')
        plt.xlabel('t-SNE feature 1')
        plt.ylabel('t-SNE feature 2')
        plt.show()

    @staticmethod
    def visualize_with_pca(features, labels, class_names):
        """Scatter-plot a 2-D PCA projection of ``features``, one colour per class."""
        pca = PCA(n_components=2)
        pca_results = pca.fit_transform(features)

        plt.figure(figsize=(10, 8))
        for i, class_name in enumerate(class_names):
            indices = labels == i
            plt.scatter(pca_results[indices, 0], pca_results[indices, 1], label=class_name)
        plt.legend()
        plt.title('PCA visualization of the dataset')
        plt.xlabel('PCA feature 1')
        plt.ylabel('PCA feature 2')
        plt.show()






class ImbalanceCIFAR10(ImbalanceCIFAR):
    """``ImbalanceCIFAR`` with ``cifar_version`` fixed to 10 (default ratio 0.1)."""

    def __init__(self, root='./data', train=True, transform=None, imbalance_ratio=0.1):
        super().__init__(
            cifar_version=10,
            root=root,
            train=train,
            transform=transform,
            imbalance_ratio=imbalance_ratio,
        )


class ImbalanceCIFAR100(ImbalanceCIFAR):
    """``ImbalanceCIFAR`` with ``cifar_version`` fixed to 100 (default ratio 0.1)."""

    def __init__(self, root='./data', train=True, transform=None, imbalance_ratio=0.1):
        super().__init__(
            cifar_version=100,
            root=root,
            train=train,
            transform=transform,
            imbalance_ratio=imbalance_ratio,
        )

# Image preprocessing: convert to a tensor, then normalise each channel with
# mean 0.5 and std 0.5. Train and test use the same pipeline.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Long-tailed training sets.
cifar10_lt = ImbalanceCIFAR10(train=True, transform=transform, imbalance_ratio=0.1)
cifar100_lt = ImbalanceCIFAR100(train=True, transform=transform, imbalance_ratio=0.02)
# Test sets: imbalance_ratio=1 keeps every sample of every class.
cifar10_test = ImbalanceCIFAR10(train=False, transform=test_transform, imbalance_ratio=1)
cifar100_test = ImbalanceCIFAR100(train=False, transform=test_transform, imbalance_ratio=1)






cpu
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Samples per class of the long-tailed CIFAR-100 training set, in class-index order.
class_distribution = cifar100_lt.get_class_distribution()
class_counts = [*class_distribution.values()]

In [5]:
# Inverse-frequency class weights, rescaled so that they sum to 1.
inverse_freq = [1.0 / n_samples for n_samples in class_counts]
sum_inv_freq = sum(inverse_freq)
normalized_weights = [weight / sum_inv_freq for weight in inverse_freq]

In [6]:
# Per-class weights as a tensor on the active device (used as focal-loss alpha).
weights_tensor = torch.tensor(normalized_weights)
weights_tensor = weights_tensor.to(device)

In [7]:


class MixupLTDataloader():
    """Class-balanced mixup batch generator for a long-tailed dataset.

    Every mixed sample is built from two source samples whose classes are drawn
    uniformly at random, so rare classes are seen as often as frequent ones.

    Args:
        batch_size: number of mixed samples per generated batch.
        dataset: object exposing ``__len__``, ``__getitem__`` returning
            ``(sample, label)``, ``num_classes`` and ``targets`` where
            ``targets[i]`` is the label of ``dataset[i]``.
        alpha: fixed mixing coefficient; the first sample is weighted by
            ``alpha`` and the second by ``1 - alpha``.
    """

    def __init__(self, batch_size, dataset, alpha=0.2):
        self.n_samples = len(dataset)
        self.n_classes = dataset.num_classes
        self.batch_size = batch_size
        self.alpha = alpha
        self.n_iterations = math.ceil(self.n_samples / batch_size)
        self.labels = np.array(dataset.targets)
        self.dataset = dataset
        # Pre-compute the dataset positions of every class once instead of
        # scanning all labels for each drawn sample.
        self._class_indices = [np.flatnonzero(self.labels == c) for c in range(self.n_classes)]
        self._nonempty_classes = [c for c, idx in enumerate(self._class_indices) if len(idx) > 0]

    def random_class(self):
        """Return a uniformly drawn class index among classes that have samples."""
        if not self._nonempty_classes:
            raise ValueError("dataset contains no samples to draw from")
        return self._nonempty_classes[np.random.randint(len(self._nonempty_classes))]

    def random_sample(self, random_class):
        """Return a random ``(sample, label)`` belonging to ``random_class``.

        Raises:
            ValueError: if the class has no samples in the dataset.
        """
        indices = self._class_indices[random_class]
        if len(indices) == 0:
            raise ValueError("class {} has no samples".format(random_class))
        # randint's upper bound is exclusive; map the drawn position back to
        # the actual dataset index of that class member.
        rand_idx = int(indices[np.random.randint(len(indices))])
        rand_sample, rand_label = self.dataset[rand_idx]
        return rand_sample, rand_label

    def generate_sample(self):
        """Draw a class uniformly, then a sample of that class."""
        rand_class = self.random_class()
        return self.random_sample(rand_class)

    def generate_mixup(self):
        """Return one mixed sample and its soft label.

        Returns:
            ``(x_hat, probability)`` where ``x_hat = alpha * x1 + (1 - alpha) * x2``
            and ``probability`` is a float tensor of length ``n_classes`` holding
            ``alpha`` at ``y1`` and ``1 - alpha`` at ``y2`` (summed when
            ``y1 == y2``).
        """
        x1, y1 = self.generate_sample()
        x2, y2 = self.generate_sample()

        x_hat = self.alpha * x1 + (1 - self.alpha) * x2

        # Mix the one-hot labels, not the integer class ids: interpolating ids
        # would assign mass to unrelated classes lying numerically in between.
        probability = torch.zeros(self.n_classes)
        probability[int(y1)] += self.alpha
        probability[int(y2)] += 1 - self.alpha
        return x_hat, probability

    def generate_batch(self):
        """Return ``(data, labels)`` stacked from ``batch_size`` mixed samples."""
        data = []
        labels = []
        while len(data) < self.batch_size:
            x, y = self.generate_mixup()
            data.append(x)
            labels.append(y)

        return torch.stack(data), torch.stack(labels)

In [8]:
# Module-level mixup batch generator over the long-tailed CIFAR-100 training set.
mixup = MixupLTDataloader(batch_size=4, dataset=cifar100_lt)

In [9]:
"""ResNet in PyTorch.
ImageNet-Style ResNet
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
Adapted from: https://github.com/bearpaw/pytorch-classification
"""


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut connection.

    The shortcut is the identity unless the stride or the channel count
    changes, in which case it is a strided 1x1 convolution followed by batch
    norm. With ``is_last=True`` the forward pass additionally returns the
    pre-activation tensor.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super().__init__()
        self.is_last = is_last
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride != 1 or in_planes != out_planes:
            # Shapes differ between input and output: project the input.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(residual)``, or ``(relu(residual), residual)`` if ``is_last``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        preact = self.bn2(self.conv2(hidden))
        preact += self.shortcut(x)
        activated = F.relu(preact)
        return (activated, preact) if self.is_last else activated


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    The block outputs ``expansion * planes`` channels. The shortcut is the
    identity unless the stride or the channel count changes, in which case it
    is a strided 1x1 convolution followed by batch norm. With ``is_last=True``
    the forward pass additionally returns the pre-activation tensor.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super().__init__()
        self.is_last = is_last
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride != 1 or in_planes != out_planes:
            # Shapes differ between input and output: project the input.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(residual)``, or ``(relu(residual), residual)`` if ``is_last``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        preact = self.bn3(self.conv3(hidden))
        preact += self.shortcut(x)
        activated = F.relu(preact)
        return (activated, preact) if self.is_last else activated


class ResNet(nn.Module):
    """ResNet feature encoder with a 3x3 stem and four residual stages.

    The forward pass returns the globally average-pooled, flattened features;
    there is no classification layer in this module.

    Args:
        block: residual block class, called as ``block(in_planes, planes, stride)``
            and exposing an ``expansion`` attribute.
        num_blocks: number of blocks in each of the four stages.
        in_channel: number of input image channels.
        zero_init_residual: zero the last batch-norm weight of every
            ``Bottleneck`` / ``BasicBlock`` so each block starts as an identity.
    """

    def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False):
        super().__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # (output planes, stride of the first block) for layer1 .. layer4.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for stage_no, (planes, stride) in enumerate(stage_specs):
            stage = self._make_layer(block, planes, num_blocks[stage_no], stride=stride)
            setattr(self, 'layer{}'.format(stage_no + 1), stage)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves
        # like an identity. This improves the model by 0.2~0.3% according to:
        # https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for module in self.modules():
                if isinstance(module, Bottleneck):
                    nn.init.constant_(module.bn3.weight, 0)
                elif isinstance(module, BasicBlock):
                    nn.init.constant_(module.bn2.weight, 0)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        strides = [stride] + [1] * (num_blocks - 1)
        blocks = []
        for block_stride in strides[:num_blocks]:
            blocks.append(block(self.in_planes, planes, block_stride))
            # The next block consumes this block's (expanded) output channels.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x, layer=100):
        """Return flattened pooled features; ``layer`` is accepted but not used."""
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        return torch.flatten(self.avgpool(out), 1)


def resnet18(**kwargs):
    """Build a ``ResNet`` of ``BasicBlock``s with stage depths [2, 2, 2, 2]."""
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths, **kwargs)


def resnet34(**kwargs):
    """Build a ``ResNet`` of ``BasicBlock``s with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths, **kwargs)


def resnet50(**kwargs):
    """Build a ``ResNet`` of ``Bottleneck`` blocks with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)


def resnet101(**kwargs):
    """Build a ``ResNet`` of ``Bottleneck`` blocks with stage depths [3, 4, 23, 3]."""
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)


# Encoder registry: name -> [constructor, dimensionality of the encoder output].
model_dict = dict(
    resnet18=[resnet18, 512],
    resnet34=[resnet34, 512],
    resnet50=[resnet50, 2048],
    resnet101=[resnet101, 2048],
)


class LinearBatchNorm(nn.Module):
    """Batch-normalise ``(N, dim)`` features through a ``BatchNorm2d`` layer.

    Implements BatchNorm1d by BatchNorm2d, for SyncBN purpose: the input is
    viewed as ``(N, dim, 1, 1)``, normalised, and viewed back as ``(N, dim)``.
    """

    def __init__(self, dim, affine=True):
        super().__init__()
        self.dim = dim
        self.bn = nn.BatchNorm2d(dim, affine=affine)

    def forward(self, x):
        as_image = x.view(-1, self.dim, 1, 1)
        normalised = self.bn(as_image)
        return normalised.view(-1, self.dim)


class SupConResNet(nn.Module):
    """Encoder backbone + projection head + linear classifier on the projection.

    Args:
        name: key into ``model_dict`` selecting the encoder.
        head: ``'linear'`` or ``'mlp'`` projection head.
        feat_dim: dimensionality of the projected features.
        num_classes: number of outputs of the classifier ``fc``.

    Raises:
        NotImplementedError: for any other ``head`` value.
    """

    def __init__(self, name='resnet50', head='mlp', feat_dim=128, num_classes = 10):
        super().__init__()
        model_fun, dim_in = model_dict[name]
        self.encoder = model_fun()
        self.num_classes = num_classes
        self.fc = nn.Linear(feat_dim, num_classes)

        if head == 'mlp':
            self.head = nn.Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(inplace=True),
                nn.Linear(dim_in, feat_dim),
            )
        elif head == 'linear':
            self.head = nn.Linear(dim_in, feat_dim)
        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
        """Return ``(feat, logits)``.

        ``feat`` holds the L2-normalised projections of the whole batch;
        ``logits`` are computed from the first half of the batch only.
        NOTE(review): presumably the batch is two augmented views concatenated
        along dim 0 — confirm with the caller. An odd batch size makes
        ``torch.split`` raise.
        """
        feat = F.normalize(self.head(self.encoder(x)), dim=1)
        half = feat.shape[0] // 2
        first_half, _ = torch.split(feat, [half, half], dim=0)
        return feat, self.fc(first_half)


# Cross-entropy model
class SupCEResNet(nn.Module):
    """Encoder followed by a linear classifier, for cross-entropy style training.

    Args:
        name: key into ``model_dict`` selecting the encoder.
        num_classes: number of output logits.
    """

    def __init__(self, name='resnet50', num_classes=10):
        super().__init__()
        model_fun, dim_in = model_dict[name]
        self.encoder = model_fun()
        self.fc = nn.Linear(dim_in, num_classes)

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        features = self.encoder(x)
        return self.fc(features)


class LinearClassifier(nn.Module):
    """Single linear layer mapping encoder features to class logits.

    The input width is the feature dimension registered in ``model_dict``
    for the encoder called ``name``.
    """

    def __init__(self, name='resnet50', num_classes=10):
        super().__init__()
        feat_dim = model_dict[name][1]
        self.fc = nn.Linear(feat_dim, num_classes)

    def forward(self, features):
        logits = self.fc(features)
        return logits

In [10]:
class TwoCropTransform:
    """Apply the wrapped transform twice to one input and return both results."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        # Two independent calls; a random transform therefore gives two different crops.
        return [self.transform(x) for _ in range(2)]


class AverageMeter(object):
    """Track the latest value, the weighted sum, the count and the running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set all statistics back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, observed ``n`` times."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Return top-k accuracies (in percent) for every ``k`` in ``topk``.

    Args:
        output: ``(batch, classes)`` scores.
        target: ``(batch,)`` class indices.
        topk: the values of k to evaluate.

    Returns:
        A list with one 1-element tensor per requested k.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # (largest_k, batch): row r holds the (r+1)-th best prediction of every sample.
        ranked = output.topk(largest_k, 1, True, True)[1].t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        scale = 100.0 / n_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
            for k in topk
        ]


def adjust_learning_rate(args, optimizer, epoch):
    """Set the learning rate of every parameter group for ``epoch``.

    With ``args.cosine`` the rate follows a cosine curve from
    ``args.learning_rate`` down to ``learning_rate * lr_decay_rate ** 3`` over
    ``args.epochs``. Otherwise it is multiplied by ``args.lr_decay_rate`` once
    for every entry of ``args.lr_decay_epochs`` that ``epoch`` has passed.
    """
    base_lr = args.learning_rate
    if args.cosine:
        floor_lr = base_lr * (args.lr_decay_rate ** 3)
        cosine = math.cos(math.pi * epoch / args.epochs)
        new_lr = floor_lr + (base_lr - floor_lr) * (1 + cosine) / 2
    else:
        milestones_passed = np.sum(epoch > np.asarray(args.lr_decay_epochs))
        new_lr = base_lr
        if milestones_passed > 0:
            new_lr = base_lr * (args.lr_decay_rate ** milestones_passed)

    for group in optimizer.param_groups:
        group['lr'] = new_lr


def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
    """Linearly ramp the learning rate during the warm-up epochs.

    Does nothing unless ``args.warm`` is set and ``epoch <= args.warm_epochs``.
    Otherwise the rate is interpolated between ``args.warmup_from`` and
    ``args.warmup_to`` according to the fraction of warm-up batches done
    (epochs are counted from 1).
    """
    if not (args.warm and epoch <= args.warm_epochs):
        return

    batches_done = batch_id + (epoch - 1) * total_batches
    progress = batches_done / (args.warm_epochs * total_batches)
    warm_lr = args.warmup_from + progress * (args.warmup_to - args.warmup_from)

    for group in optimizer.param_groups:
        group['lr'] = warm_lr


def set_optimizer(opt, model):
    """Return an SGD optimizer over ``model``'s parameters.

    Learning rate, momentum and weight decay are read from
    ``opt.learning_rate``, ``opt.momentum`` and ``opt.weight_decay``.
    """
    sgd_settings = {
        'lr': opt.learning_rate,
        'momentum': opt.momentum,
        'weight_decay': opt.weight_decay,
    }
    return optim.SGD(model.parameters(), **sgd_settings)


def save_model(model, optimizer, opt, epoch, save_file):
    """Write a checkpoint to ``save_file`` with ``torch.save``.

    The checkpoint is a dict with the keys ``'opt'``, ``'model'`` (state dict),
    ``'optimizer'`` (state dict) and ``'epoch'``.
    """
    print('==> Saving...')
    checkpoint = dict(
        opt=opt,
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        epoch=epoch,
    )
    torch.save(checkpoint, save_file)

In [11]:
# Rebind the global device: the default CUDA device when available, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [12]:
from typing import Optional, Sequence

import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F


class FocalLoss(nn.Module):
    """ Focal Loss, as described in https://arxiv.org/abs/1708.02002.

    It is essentially an enhancement to cross entropy loss and is
    useful for classification tasks when there is a large class imbalance.
    x is expected to contain raw, unnormalized scores for each class.
    y is expected to contain class labels.

    Shape:
        - x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0.
        - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0.
    """

    def __init__(self,
                 alpha: Optional[Tensor] = None,
                 gamma: float = 0.,
                 reduction: str = 'mean',
                 ignore_index: int = -100):
        """Constructor.

        Args:
            alpha (Tensor, optional): Weights for each class. Defaults to None.
            gamma (float, optional): A constant, as described in the paper.
                Defaults to 0.
            reduction (str, optional): 'mean', 'sum' or 'none'.
                Defaults to 'mean'.
            ignore_index (int, optional): class label to ignore.
                Defaults to -100.

        Raises:
            ValueError: if ``reduction`` is not one of the supported values.
        """
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(
                'Reduction must be one of: "mean", "sum", "none".')

        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction

        # reduction='none' because the focal term is applied per sample
        # before the final reduction in forward().
        self.nll_loss = nn.NLLLoss(
            weight=alpha, reduction='none', ignore_index=ignore_index)

    def __repr__(self):
        arg_keys = ['alpha', 'gamma', 'ignore_index', 'reduction']
        arg_strs = [f'{k}={getattr(self, k)!r}' for k in arg_keys]
        arg_str = ', '.join(arg_strs)
        return f'{type(self).__name__}({arg_str})'

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        """Return the focal loss of logits ``x`` against class labels ``y``.

        Samples labelled ``ignore_index`` are dropped. If nothing remains, a
        zero scalar on ``x``'s device and with ``x``'s dtype is returned.
        """
        if x.ndim > 2:
            # (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
            c = x.shape[1]
            x = x.permute(0, *range(2, x.ndim), 1).reshape(-1, c)
            # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,)
            # reshape (not view) so non-contiguous targets are accepted too.
            y = y.reshape(-1)

        unignored_mask = y != self.ignore_index
        y = y[unignored_mask]
        if len(y) == 0:
            # Keep the result on the input's device/dtype so callers can
            # combine it with other losses without a device mismatch.
            return torch.tensor(0., device=x.device, dtype=x.dtype)
        x = x[unignored_mask]

        # compute weighted cross entropy term: -alpha * log(pt)
        # (alpha is already part of self.nll_loss)
        log_p = F.log_softmax(x, dim=-1)
        ce = self.nll_loss(log_p, y)

        # get true class column from each row; gather stays on x's device
        log_pt = log_p.gather(1, y.unsqueeze(1)).squeeze(1)

        # compute focal term: (1 - pt)^gamma
        pt = log_pt.exp()
        focal_term = (1 - pt)**self.gamma

        # the full loss: -alpha * ((1 - pt)^gamma) * log(pt)
        loss = focal_term * ce

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()

        return loss


def focal_loss(alpha: Optional[Sequence] = None,
               gamma: float = 0.,
               reduction: str = 'mean',
               ignore_index: int = -100,
               device='cpu',
               dtype=torch.float32) -> FocalLoss:
    """Build a ``FocalLoss``, converting ``alpha`` to a tensor first.

    Args:
        alpha (Sequence, optional): per-class weights. A non-tensor value is
            wrapped with ``torch.tensor``; the result is moved to ``device``
            and cast to ``dtype``. ``None`` is passed through unchanged.
        gamma (float, optional): focusing constant from the paper. Defaults to 0.
        reduction (str, optional): 'mean', 'sum' or 'none'. Defaults to 'mean'.
        ignore_index (int, optional): class label to ignore. Defaults to -100.
        device (str, optional): device to move alpha to. Defaults to 'cpu'.
        dtype (torch.dtype, optional): dtype to cast alpha to.
            Defaults to torch.float32.

    Returns:
        A FocalLoss object
    """
    if alpha is not None:
        as_tensor = alpha if isinstance(alpha, Tensor) else torch.tensor(alpha)
        alpha = as_tensor.to(device=device, dtype=dtype)

    return FocalLoss(
        alpha=alpha,
        gamma=gamma,
        reduction=reduction,
        ignore_index=ignore_index)

In [13]:
from __future__ import print_function

import sys
import argparse
import time
import math

import tensorboard_logger as tb_logger

import torch.backends.cudnn as cudnn


import argparse
import os
import math


def parse_option(batch_size=32, epochs=10, print_freq=10, save_freq=9, num_workers=16,
                 learning_rate=0.05, lr_decay_epochs='700,800,900', lr_decay_rate=0.1,
                 weight_decay=1e-4, momentum=0.9, model='resnet50', dataset='cifar100',
                 mean=None, std=None, data_folder=None, size=32, method='SupCon',
                 temp=0.07, cosine=False, syncBN=False, warm=False, trial='0'):
    """Build the training options object and create the output folders.

    Every keyword argument is stored under the same name on the returned
    object. Derived attributes added on top: ``model_path``, ``tb_path``,
    ``model_name``, ``tb_folder``, ``save_folder``, ``n_cls`` and, when warm-up
    is active (``warm=True`` or ``batch_size > 256``), ``warmup_from``,
    ``warm_epochs`` and ``warmup_to``.

    Args:
        lr_decay_epochs: comma-separated string (e.g. ``'700,800,900'``) or an
            iterable of epoch numbers; stored as a list of ints.
        dataset: ``'cifar10'`` or ``'cifar100'``.

    Returns:
        An ``argparse.Namespace``. A module-level type is used (rather than a
        class defined inside this function) so the options can be pickled,
        e.g. when they are stored inside a checkpoint.

    Raises:
        ValueError: if ``dataset`` is not supported. This is checked before
            any directory is created.
    """
    opt = argparse.Namespace(
        batch_size=batch_size,
        epochs=epochs,
        print_freq=print_freq,
        save_freq=save_freq,
        num_workers=num_workers,
        learning_rate=learning_rate,
        lr_decay_epochs=lr_decay_epochs,
        lr_decay_rate=lr_decay_rate,
        weight_decay=weight_decay,
        momentum=momentum,
        model=model,
        dataset=dataset,
        mean=mean,
        std=std,
        data_folder=data_folder,
        size=size,
        method=method,
        temp=temp,
        cosine=cosine,
        syncBN=syncBN,
        warm=warm,
        trial=trial,
    )

    # Additional processing as in the original function
    if opt.dataset == 'path':
        assert opt.data_folder is not None and opt.mean is not None and opt.std is not None

    # Validate the dataset up front so a bad name does not leave folders behind.
    classes_per_dataset = {'cifar10': 10, 'cifar100': 100}
    if opt.dataset not in classes_per_dataset:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))
    opt.n_cls = classes_per_dataset[opt.dataset]

    if opt.data_folder is None:
        opt.data_folder = './datasets/'

    opt.model_path = './save/SupCon/{}_models'.format(opt.dataset)
    opt.tb_path = './save/SupCon/{}_tensorboard'.format(opt.dataset)

    if isinstance(opt.lr_decay_epochs, str):
        opt.lr_decay_epochs = [int(it) for it in opt.lr_decay_epochs.split(',')]
    else:
        opt.lr_decay_epochs = [int(it) for it in opt.lr_decay_epochs]

    opt.model_name = 'SupCE_{}_{}_lr_{}_decay_{}_bsz_{}_trial_{}'.\
        format(opt.dataset, opt.model, opt.learning_rate, opt.weight_decay,
               opt.batch_size, opt.trial)

    if opt.cosine:
        opt.model_name = '{}_cosine'.format(opt.model_name)

    # warm-up for large-batch training,
    if opt.batch_size > 256:
        opt.warm = True
    if opt.warm:
        opt.model_name = '{}_warm'.format(opt.model_name)
        opt.warmup_from = 0.01
        opt.warm_epochs = 10
        if opt.cosine:
            # Warm up to the value the cosine schedule has after the warm-up epochs.
            eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
            opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
                    1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
        else:
            opt.warmup_to = opt.learning_rate

    opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
    os.makedirs(opt.tb_folder, exist_ok=True)

    opt.save_folder = os.path.join(opt.model_path, opt.model_name)
    os.makedirs(opt.save_folder, exist_ok=True)

    return opt


def set_loader(opt):
    """Build the CIFAR train and validation data loaders.

    Reads ``opt.dataset`` (``'cifar10'`` or ``'cifar100'``), ``opt.data_folder``,
    ``opt.batch_size`` and ``opt.num_workers``. Training images get a random
    resized crop and a random horizontal flip; validation images are only
    converted to tensors and normalised.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: if ``opt.dataset`` is not supported.
    """
    # Per-channel (mean, std) used for normalisation.
    channel_stats = {
        'cifar10': ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        'cifar100': ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    }
    if opt.dataset not in channel_stats:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))
    mean, std = channel_stats[opt.dataset]
    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    dataset_cls = datasets.CIFAR10 if opt.dataset == 'cifar10' else datasets.CIFAR100
    train_dataset = dataset_cls(root=opt.data_folder,
                                transform=train_transform,
                                download=True)
    val_dataset = dataset_cls(root=opt.data_folder,
                              train=False,
                              transform=val_transform)

    # No custom sampler is used, so the training data is simply shuffled.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size, shuffle=True,
        num_workers=opt.num_workers, pin_memory=True, sampler=None)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=256, shuffle=False,
        num_workers=8, pin_memory=True)

    return train_loader, val_loader


def set_model(opt, use_CE=True):
    """Create the classifier and its loss, and move both to the GPU if available.

    Args:
        opt: options object; ``opt.model``, ``opt.n_cls`` and ``opt.syncBN``
            are read.
        use_CE: ``True`` for plain cross entropy, ``False`` for focal loss
            (gamma=2) weighted by the module-level ``weights_tensor``.

    Returns:
        ``(model, criterion)``.
    """
    model = SupCEResNet(name=opt.model, num_classes=opt.n_cls)
    if use_CE:
        criterion = torch.nn.CrossEntropyLoss()
    else:
        # NOTE(review): weights_tensor is a module-level global; confirm its
        # length matches opt.n_cls for the dataset being trained.
        criterion = focal_loss(alpha=weights_tensor, gamma=2.0, reduction='mean')

    # enable synchronized Batch Normalization
    if opt.syncBN:
        # `apex` is never imported in this file, so the former
        # apex.parallel.convert_syncbn_model call raised NameError. Use the
        # converter that ships with PyTorch instead.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    return model, criterion


def train(train_loader, model, criterion, optimizer, epoch, opt):
    """Run one training epoch on batches drawn from the global ``mixup``.

    Batches come from ``mixup.generate_batch()``; ``train_loader`` is not
    iterated here and is only used for ``len(train_loader)`` in the warm-up
    schedule and the progress print.

    Args:
        train_loader: loader whose length is passed to the warm-up schedule.
        model: network to optimize; switched to training mode.
        criterion: loss callable invoked as ``criterion(output, labels)``.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch number, forwarded to the warm-up schedule and
            included in the progress print.
        opt: options object; ``opt.print_freq`` is read here and ``opt`` is
            forwarded to ``warmup_learning_rate``.

    Returns:
        ``(losses.avg, top1.avg)``: running averages of the loss and of the
        top-1 accuracy over the epoch.
    """
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    end = time.time()
    iterations = math.ceil(mixup.n_samples / mixup.batch_size)

    for idx in range(iterations):
        images, labels = mixup.generate_batch()
        data_time.update(time.time() - end)
        images = images.to(device)
        # Labels must live on the same device as the model output, otherwise
        # the loss and accuracy computations fail on GPU.
        labels = labels.to(device)
        bsz = labels.shape[0]

        # warm-up learning rate
        # NOTE(review): the schedule and the progress print use
        # len(train_loader) while the loop runs `iterations` steps derived
        # from `mixup`; confirm the two are meant to agree.
        warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)

        # compute loss
        output = model(images)
        loss = criterion(output, labels.squeeze())

        # update metric
        losses.update(loss.item(), bsz)
        acc1, _ = accuracy(output, labels, topk=(1, 5))
        top1.update(acc1[0], bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                   epoch, idx + 1, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1))
            sys.stdout.flush()

    return losses.avg, top1.avg


def validate(val_loader, model, criterion, opt):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        val_loader: iterable yielding ``(images, labels)`` batches.
        model: network to evaluate; switched to eval mode.
        criterion: loss callable invoked as ``criterion(output, labels)``.
        opt: options object; ``opt.print_freq`` controls how often progress
            is printed.

    Returns:
        ``(losses.avg, top1.avg)``: running averages of the loss and of the
        top-1 accuracy over the whole loader.
    """
    model.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    with torch.no_grad():
        end = time.time()
        for idx, (images, labels) in enumerate(val_loader):
            # Use the module-level `device` (as the training loop does) so
            # evaluation also works on hosts without CUDA.
            images = images.float().to(device)
            labels = labels.to(device)
            bsz = labels.shape[0]

            # forward
            output = model(images)
            loss = criterion(output, labels)

            # update metric
            losses.update(loss.item(), bsz)
            acc1, _ = accuracy(output, labels, topk=(1, 5))
            top1.update(acc1[0], bsz)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if idx % opt.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                       idx, len(val_loader), batch_time=batch_time,
                       loss=losses, top1=top1))

    print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
    return losses.avg, top1.avg


def main():
    """Train for ``opt.epochs`` epochs, evaluating and logging after each.

    The data loaders are not built here: the function reads the module-level
    names ``train_loader_cifar100_lt`` and ``val_loader_cifar100_lt``, which
    must already exist when it is called.
    """
    opt = parse_option()

    # Model and loss; the second argument is use_CE=False.
    model, criterion = set_model(opt,False)

    optimizer = set_optimizer(opt, model)

    # tensorboard writer
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    best_acc = 0
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)

        # One pass over the training data, timed.
        started = time.time()
        train_loss, train_acc = train(
            train_loader_cifar100_lt, model, criterion, optimizer, epoch, opt)
        finished = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, finished - started))

        # Training scalars for tensorboard.
        train_scalars = (
            ('train_loss', train_loss),
            ('train_acc', train_acc),
            ('learning_rate', optimizer.param_groups[0]['lr']),
        )
        for tag, value in train_scalars:
            logger.log_value(tag, value, epoch)

        # Evaluation on the validation loader.
        val_loss, val_acc = validate(
            val_loader_cifar100_lt, model, criterion, opt)
        for tag, value in (('val_loss', val_loss), ('val_acc', val_acc)):
            logger.log_value(tag, value, epoch)

        # Keep the best validation accuracy seen so far.
        best_acc = max(best_acc, val_acc)

        # Periodic checkpoint.
        if epoch % opt.save_freq == 0:
            ckpt_name = 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)
            save_model(model, optimizer, opt, epoch,
                       os.path.join(opt.save_folder, ckpt_name))

    # Always save the final state.
    last_path = os.path.join(opt.save_folder, 'last.pth')
    save_model(model, optimizer, opt, opt.epochs, last_path)

    print('best accuracy: {:.2f}'.format(best_acc))


# Run the training entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()

NameError: name 'train_loader_cifar100_lt' is not defined