In [1]:
# Create directories for saving model checkpoints
!mkdir -p checkpoint/ours/pretrain/
!mkdir -p ../cifar-10 # for downloading dataset

# Install dependencies
print("\n--- Installing dependencies ---")
!pip install torch==2.0.0 torchvision==0.15.1 tqdm -q
print("--> Dependencies installed.")


--- Installing dependencies ---
--> Dependencies installed.


In [2]:
!pip install "numpy<2.0"



In [3]:
%%writefile resnet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.autograd import Variable
import torch.nn.init as init

def to_var(x, requires_grad=True):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, requires_grad=requires_grad)

class resnet_attention(nn.Module):
    def __init__(self, enc_hid_dim=64, dec_hid_dim=100):
        super(resnet_attention, self).__init__()
        self.attn = nn.Linear(enc_hid_dim , dec_hid_dim, bias=True)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)
    def forward(self, s):
        energy = torch.tanh(self.attn(s))
        attention = self.v(energy)
        return  F.softmax(attention, dim=0)

class MetaModule(nn.Module):
    def params(self):
        for name, param in self.named_params(self):
            yield param
    def named_leaves(self):
        return []
    def named_submodules(self):
        return []
    def named_params(self, curr_module=None, memo=None, prefix=''):
        if memo is None:
            memo = set()
        if hasattr(curr_module, 'named_leaves'):
            for name, p in curr_module.named_leaves():
                if p is not None and p not in memo:
                    memo.add(p)
                    yield prefix + ('.' if prefix else '') + name, p
        else:
            for name, p in curr_module._parameters.items():
                if p is not None and p not in memo:
                    memo.add(p)
                    yield prefix + ('.' if prefix else '') + name, p
        for mname, module in curr_module.named_children():
            submodule_prefix = prefix + ('.' if prefix else '') + mname
            for name, p in self.named_params(module, memo, submodule_prefix):
                yield name, p
    def update_params(self, lr_inner, first_order=False, source_params=None, detach=False):
        if source_params is not None:
            for tgt, src in zip(self.named_params(self), source_params):
                name_t, param_t = tgt
                grad = src
                if first_order:
                    grad = to_var(grad.detach().data)
                tmp = param_t - lr_inner * grad
                self.set_param(self, name_t, tmp)
        else:
            for name, param in self.named_params(self):
                if not detach:
                    grad = param.grad
                    if first_order:
                        grad = to_var(grad.detach().data)
                    tmp = param - lr_inner * grad
                    self.set_param(self, name, tmp)
                else:
                    param = param.detach_()
                    self.set_param(self, name, param)
    def set_param(self, curr_mod, name, param):
        if '.' in name:
            n = name.split('.')
            module_name = n[0]
            rest = '.'.join(n[1:])
            for name, mod in curr_mod.named_children():
                if module_name == name:
                    self.set_param(mod, rest, param)
                    break
        else:
            setattr(curr_mod, name, param)
    def detach_params(self):
        for name, param in self.named_params(self):
            self.set_param(self, name, param.detach())
    def copy(self, other, same_var=False):
        for name, param in other.named_params():
            if not same_var:
                param = to_var(param.data.clone(), requires_grad=True)
            self.set_param(name, param)

class MetaLinear(MetaModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.Linear(*args, **kwargs)
        self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
        self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
    def forward(self, x):
        return F.linear(x, self.weight, self.bias)
    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]

class MetaConv2d(MetaModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.Conv2d(*args, **kwargs)
        self.in_channels = ignore.in_channels
        self.out_channels = ignore.out_channels
        self.stride = ignore.stride
        self.padding = ignore.padding
        self.dilation = ignore.dilation
        self.groups = ignore.groups
        self.kernel_size = ignore.kernel_size
        self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
        if ignore.bias is not None:
            self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
        else:
            self.register_buffer('bias', None)
    def forward(self, x):
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]

class MetaBatchNorm2d(MetaModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.BatchNorm2d(*args, **kwargs)
        self.num_features = ignore.num_features
        self.eps = ignore.eps
        self.momentum = ignore.momentum
        self.affine = ignore.affine
        self.track_running_stats = ignore.track_running_stats
        if self.affine:
            self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
            self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(self.num_features))
            self.register_buffer('running_var', torch.ones(self.num_features))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
    def forward(self, x):
        return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
                            self.training or not self.track_running_stats, self.momentum, self.eps)
    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]

def _weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, MetaLinear) or isinstance(m, MetaConv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(MetaModule):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd
    def forward(self, x):
        return self.lambd(x)

class BasicBlock(MetaModule):
    expansion = 1
    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = MetaConv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = MetaBatchNorm2d(planes)
        self.conv2 = MetaConv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = MetaBatchNorm2d(planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                     MetaConv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                     MetaBatchNorm2d(self.expansion * planes)
                )
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet32(MetaModule):
    def __init__(self, num_classes, block=BasicBlock, num_blocks=[5, 5, 5]):
        super(ResNet32, self).__init__()
        self.in_planes = 16
        self.conv1 = MetaConv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = MetaBatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = MetaLinear(64, num_classes)
        self.apply(_weights_init)
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        y = self.linear(out)
        return out, y

print("File 'resnet.py' has been saved.")

Overwriting resnet.py


In [4]:
%%writefile data_utils.py
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision
import numpy as np
import copy
from torch.utils.data import Dataset

np.random.seed(6)

def build_dataset(dataset,num_meta):
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4), mode='reflect').squeeze()),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])
    if dataset == 'cifar10':
        train_dataset = torchvision.datasets.CIFAR10(root='../cifar-10', train=True, download=True, transform=transform_train)
        test_dataset = torchvision.datasets.CIFAR10('../cifar-10', train=False, transform=transform_test)
        img_num_list = [num_meta] * 10
        num_classes = 10
    if dataset == 'cifar100':
        train_dataset = torchvision.datasets.CIFAR100(root='../cifar-100', train=True, download=True, transform=transform_train)
        test_dataset = torchvision.datasets.CIFAR100('../cifar-100', train=False, transform=transform_test)
        img_num_list = [num_meta] * 100
        num_classes = 100
    data_list_val = {}
    for j in range(num_classes):
        data_list_val[j] = [i for i, label in enumerate(train_dataset.targets) if label == j]
    idx_to_meta = []
    idx_to_train = []
    for cls_idx, img_id_list in data_list_val.items():
        np.random.shuffle(img_id_list)
        img_num = img_num_list[int(cls_idx)]
        idx_to_meta.extend(img_id_list[:img_num])
        idx_to_train.extend(img_id_list[img_num:])
    train_data = copy.deepcopy(train_dataset)
    train_data_meta = copy.deepcopy(train_dataset)
    train_data_meta.data = np.delete(train_dataset.data, idx_to_train,axis=0)
    train_data_meta.targets = np.delete(train_dataset.targets, idx_to_train, axis=0)
    train_data.data = np.delete(train_dataset.data, idx_to_meta, axis=0)
    train_data.targets = np.delete(train_dataset.targets, idx_to_meta, axis=0)
    return train_data_meta, train_data, test_dataset

def get_img_num_per_cls(dataset, imb_factor=None, num_meta=None):
    if dataset == 'cifar10':
        img_max = (50000-num_meta)/10
        cls_num = 10
    if dataset == 'cifar100':
        img_max = (50000-num_meta)/100
        cls_num = 100
    if imb_factor is None:
        return [int(img_max)] * cls_num
    img_num_per_cls = []
    for cls_idx in range(cls_num):
        num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0)))
        img_num_per_cls.append(int(num))
    return img_num_per_cls

class new_dataset(Dataset):
    def __init__(self, dataset, train=None):
        self.data = dataset.data
        self.targets = dataset.targets
        normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        if train:
            self.transform = transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),(4, 4, 4, 4), mode='reflect').squeeze()),
                                transforms.ToPILImage(),
                                transforms.RandomCrop(32),
                                transforms.RandomHorizontalFlip(),
                                transforms.ToTensor(),
                                normalize,
                            ])
        else:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                normalize
            ])
    def __getitem__(self, index):
        img, label = self.data[index, ::], self.targets[index]
        img = self.transform(img)
        label = torch.LongTensor([np.int64(label)])
        return img, label, index
    def __len__(self):
        return len(self.data)

print("File 'data_utils.py' has been saved.")

Overwriting data_utils.py


In [5]:
%%writefile Sinkhorn_distance.py
import torch
import torch.nn as nn
from torch.autograd.variable import Variable

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
d_cosine = nn.CosineSimilarity(dim=-1, eps=1e-8)

class SinkhornDistance(nn.Module):
    def __init__(self, eps, max_iter, dis, reduction='none'):
        super(SinkhornDistance, self).__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction
        self.dis = dis
    def forward(self, x, y, nu):
        if self.dis == 'cos':
            C = self._cost_matrix(x, y, 'cos')
        elif self.dis == 'euc':
            C = self._cost_matrix(x, y, 'euc')
        x_points = x.shape[-2]
        if x.dim() == 2:
            batch_size = 1
        else:
            batch_size = x.shape[0]
        mu = torch.empty(batch_size, x_points, dtype=torch.float, requires_grad=False).fill_(1.0 / x_points).to(device).squeeze()
        u = torch.zeros_like(mu).to(device)
        v = torch.zeros_like(nu).to(device)
        for i in range(self.max_iter):
            u1 = u
            u = self.eps * (torch.log(mu+1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
            v = self.eps * (torch.log(nu+1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
            err = (u - u1).abs().sum(-1).mean()
            if err.item() < 1e-1:
                break
        U, V = u, v
        pi = torch.exp(self.M(C, U, V))
        cost = torch.sum(pi * C, dim=(-2, -1))
        if self.reduction == 'mean':
            cost = cost.mean()
        elif self.reduction == 'sum':
            cost = cost.sum()
        return cost
    def M(self, C, u, v):
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps
    @staticmethod
    def _cost_matrix(x, y, dis, p=2):
        x_col = x.unsqueeze(-2)
        y_lin = y.unsqueeze(-3)
        if dis == 'cos':
            C = 1 - d_cosine(x_col, y_lin)
        elif dis == 'euc':
            C = torch.mean((torch.abs(x_col - y_lin)) ** p, -1)
        return C
print("File 'Sinkhorn_distance.py' has been saved.")

Overwriting Sinkhorn_distance.py


In [6]:
%%writefile Sinkhorn_distance_fl.py
import torch
import torch.nn as nn
from torch.autograd.variable import Variable

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
d_cosine = nn.CosineSimilarity(dim=-1, eps=1e-8)

class SinkhornDistance(nn.Module):
    def __init__(self, eps, max_iter, reduction='none'):
        super(SinkhornDistance, self).__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction
    def forward(self, x, y, x1, y1, nu):
        C1 = self._cost_matrix(x, y, dis='cos')
        C2 = self._cost_matrix(x1, y1, dis='euc')
        C = 0.5*C1 + 0.5*C2
        x_points = x.shape[-2]
        if x.dim() == 2:
            batch_size = 1
        else:
            batch_size = x.shape[0]
        mu = torch.empty(batch_size, x_points, dtype=torch.float, requires_grad=False).fill_(1.0 / x_points).to(device).squeeze()
        u = torch.zeros_like(mu).to(device)
        v = torch.zeros_like(nu).to(device)
        for i in range(self.max_iter):
            u1 = u
            u = self.eps * (torch.log(mu+1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
            v = self.eps * (torch.log(nu+1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
            err = (u - u1).abs().sum(-1).mean()
            if err.item() < 1e-1:
                break
        U, V = u, v
        pi = torch.exp(self.M(C, U, V))
        cost = torch.sum(pi * C, dim=(-2, -1))
        if self.reduction == 'mean':
            cost = cost.mean()
        elif self.reduction == 'sum':
            cost = cost.sum()
        return cost
    def M(self, C, u, v):
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps
    @staticmethod
    def _cost_matrix(x, y, dis, p=2):
        x_col = x.unsqueeze(-2)
        y_lin = y.unsqueeze(-3)
        if dis == 'cos':
            C = 1 - d_cosine(x_col , y_lin)
        elif dis == 'euc':
            C = torch.mean((torch.abs(x_col - y_lin)) ** p, -1)
        return C
print("File 'Sinkhorn_distance_fl.py' has been saved.")

Overwriting Sinkhorn_distance_fl.py


In [7]:
%%writefile pretrain_stage1.py
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import os
from tqdm import tqdm
import random
import copy
import numpy as np

# Import from the scripts you have
from data_utils import build_dataset, get_img_num_per_cls, new_dataset
from resnet import ResNet32
from torch.utils.data import DataLoader

def get_args():
    parser = argparse.ArgumentParser(description='Stage 1 Pre-training')
    parser.add_argument('--dataset', default='cifar10', type=str)
    parser.add_argument('--imb_factor', default=0.01, type=float)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=200) # Paper uses 200 epochs
    parser.add_argument('--lr', default=0.1, type=float)
    parser.add_argument('--gpu', default=0, type=int)
    return parser.parse_args()

def main():
    args = get_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_meta = 10 # From your OT_train.py
    num_classes = 10 if args.dataset == 'cifar10' else 100

    # 1. Create Imbalanced Dataset
    print("--> Creating Imbalanced Dataset...")
    _, train_data, test_dataset = build_dataset(args.dataset, num_meta)
    img_num_list = get_img_num_per_cls(args.dataset, args.imb_factor, num_meta * num_classes)
    data_list = {j: [i for i, label in enumerate(train_data.targets) if label == j] for j in range(num_classes)}
    idx_to_del = []
    for cls_idx, img_id_list in data_list.items():
        random.shuffle(img_id_list)
        img_num = img_num_list[int(cls_idx)]
        idx_to_del.extend(img_id_list[img_num:])
    imbalanced_train_dataset = copy.deepcopy(train_data)
    imbalanced_train_dataset.targets = np.delete(train_data.targets, idx_to_del, axis=0)
    imbalanced_train_dataset.data = np.delete(train_data.data, idx_to_del, axis=0)
    imbalanced_train_loader = DataLoader(new_dataset(imbalanced_train_dataset, train=True), batch_size=args.batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(new_dataset(test_dataset, train=False), batch_size=100, shuffle=False, num_workers=2)

    # 2. Build and Train Model
    print("--> Building and Training Model for Stage 1...")
    model = ResNet32(num_classes=num_classes).to(device)
    criterion = nn.CrossEntropyLoss()

    
    optimizer = optim.SGD(model.params(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
   

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[160, 180], gamma=0.1)

    for epoch in range(args.epochs):
        model.train()
        pbar = tqdm(imbalanced_train_loader, desc=f'Epoch {epoch+1}/{args.epochs}')
        for inputs, labels, _ in pbar:
            inputs, labels = inputs.to(device), labels.to(device).squeeze()
            _, outputs = model(inputs) # Your ResNet returns (features, logits)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_postfix({'Loss': f'{loss.item():.3f}'})
        scheduler.step()

    # 3. Save Checkpoint for OT_train.py
    save_dir = 'checkpoint/ours/pretrain/'
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    save_path = f'{save_dir}/cifar10_imb{args.imb_factor}_stage1.pth'
    print(f"--> Stage 1 training complete. Saving model to {save_path}")
    torch.save({'state_dict': model.state_dict()}, save_path)

if __name__ == '__main__':
    main()

print("File 'pretrain_stage1.py' has been saved.")

Overwriting pretrain_stage1.py


In [8]:
%%writefile OT_train.py
import os
import time
import argparse
import random
import copy
import torch
import torchvision
import numpy as np
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as transforms
from data_utils import *
from resnet import *
import shutil
from Sinkhorn_distance import SinkhornDistance
from Sinkhorn_distance_fl import SinkhornDistance as SinkhornDistance_fl
from torch.utils.data import TensorDataset, DataLoader
import torch.backends.cudnn as cudnn

parser = argparse.ArgumentParser(description='Imbalanced Example')
parser.add_argument('--dataset', default='cifar10', type=str)
parser.add_argument('--cost', default='combined', type=str)
parser.add_argument('--meta_set', default='prototype', type=str)
parser.add_argument('--batch-size', type=int, default=16, metavar='N')
parser.add_argument('--num_classes', type=int, default=10)
parser.add_argument('--num_meta', type=int, default=10)
parser.add_argument('--imb_factor', type=float, default=0.005)
parser.add_argument('--epochs', type=int, default=250, metavar='N')
parser.add_argument('--lr', '--learning-rate', default=2e-5, type=float)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--nesterov', default=True, type=bool)
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float)
parser.add_argument('--no-cuda', action='store_true', default=False)
parser.add_argument('--seed', type=int, default=42, metavar='S')
parser.add_argument('--print-freq', '-p', default=100, type=int)
parser.add_argument('--gpu', default=0, type=int)
parser.add_argument('--save_name', default='OT_cifar10_imb0.005', type=str)
parser.add_argument('--idx', default='ours', type=str)
parser.add_argument('--ckpt_path', type=str, help='Path to pre-trained model checkpoint')


def main():
    global args, best_prec1
    args = parser.parse_args()
    for arg in vars(args):
        print(f"{arg}={getattr(args, arg)}")

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    kwargs = {'num_workers': 0, 'pin_memory': False}
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")

    train_data_meta, train_data, test_dataset = build_dataset(args.dataset, args.num_meta)

    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, **kwargs)

    np.random.seed(42)
    random.seed(42)
    torch.manual_seed(args.seed)

    data_list = {}
    for j in range(args.num_classes):
        data_list[j] = [i for i, label in enumerate(train_loader.dataset.targets) if label == j]

    img_num_list = get_img_num_per_cls(args.dataset, args.imb_factor, args.num_meta*args.num_classes)

    idx_to_del = []
    for cls_idx, img_id_list in data_list.items():
        random.shuffle(img_id_list)
        img_num = img_num_list[int(cls_idx)]
        idx_to_del.extend(img_id_list[img_num:])

    imbalanced_train_dataset = copy.deepcopy(train_data)
    imbalanced_train_dataset.targets = np.delete(train_loader.dataset.targets, idx_to_del, axis=0)
    imbalanced_train_dataset.data = np.delete(train_loader.dataset.data, idx_to_del, axis=0)

    imbalanced_train_loader = DataLoader(new_dataset(imbalanced_train_dataset, train=True),
                                         batch_size=args.batch_size, shuffle=True, **kwargs)
    validation_loader = DataLoader(new_dataset(train_data_meta, train=True),
                                   batch_size=args.num_classes*args.num_meta, shuffle=False, **kwargs)
    test_loader = DataLoader(new_dataset(test_dataset, train=False),
                             batch_size=args.batch_size, shuffle=False, **kwargs)

    best_prec1 = 0

    beta = 0.9999
    effective_num = 1.0 - np.power(beta, img_num_list)
    per_cls_weights = (1.0 - beta) / np.array(effective_num)
    per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(img_num_list)
    per_cls_weights = torch.FloatTensor(per_cls_weights).to(device)
    weightsbuffer = torch.tensor([per_cls_weights[cls_i] for cls_i in imbalanced_train_dataset.targets]).to(device)

    eplisons = 0.1
    criterion = SinkhornDistance(eps=eplisons, max_iter=200, reduction=None, dis='cos').to(device)
    criterion_label = SinkhornDistance(eps=eplisons, max_iter=200, reduction=None, dis='euc').to(device)
    criterion_fl = SinkhornDistance_fl(eps=eplisons, max_iter=200, reduction=None).to(device)

    model = build_model(load_pretrain=True, ckpt_path=args.ckpt_path)
    if not model:
        print("Exiting: Failed to build model.")
        return

    optimizer_a = torch.optim.SGD(model.linear.params(), args.lr,
                                  momentum=args.momentum, nesterov=args.nesterov,
                                  weight_decay=args.weight_decay)

    cudnn.benchmark = True
    criterion_classifier = nn.CrossEntropyLoss(reduction='none').to(device)

    for epoch in range(160, args.epochs):

        train_OT(imbalanced_train_loader, validation_loader, weightsbuffer,
                 model, optimizer_a, epoch, criterion_classifier, device,
                 criterion, criterion_label, criterion_fl)

        prec1, _, _ = validate(test_loader, model, device)

        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if is_best:
            save_checkpoint(args, {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_acc1': best_prec1,
                'optimizer': optimizer_a.state_dict(),
            }, is_best)

    print('Best accuracy: ', best_prec1)


def train_OT(train_loader, validation_loader, weightsbuffer, model, optimizer, epoch, criterion_classifier, device, criterion, criterion_label, criterion_fl):
    losses = AverageMeter()
    top1 = AverageMeter()
    model.train()

    val_data, val_labels, _ = next(iter(validation_loader))
    val_data = to_var(val_data.to(device), requires_grad=False)
    val_labels = to_var(val_labels.to(device), requires_grad=False).squeeze()

    if args.meta_set == 'prototype':
        val_data_bycls = torch.zeros([args.num_classes, 3, 32, 32]).to(device)
        for i_cls in range(args.num_classes):
            class_samples = val_data[val_labels == i_cls]
            if len(class_samples) > 0:
                 val_data_bycls[i_cls, ::] = class_samples.mean(dim=0)
        val_labels_bycls = torch.tensor([i_l for i_l in range(args.num_classes)]).to(device)
    else: # 'whole'
        val_data_bycls = val_data
        val_labels_bycls = val_labels

    val_labels_onehot = to_categorical(val_labels_bycls).to(device)
    with torch.no_grad():
        feature_val, _ = model(val_data_bycls)

    for i, batch in enumerate(train_loader):
        inputs, labels, ids = tuple(t.to(device) for t in batch)
        labels = labels.squeeze()
        labels_onehot = to_categorical(labels).to(device)
        weights = to_var(weightsbuffer[ids])
        model.eval()
        Attoptimizer = torch.optim.SGD([weights], lr=0.01, momentum=0.9, weight_decay=5e-4)
        with torch.no_grad():
            feature_train, _ = model(inputs)
        probability_train = softmax_normalize(weights)

        if args.cost == 'feature':
            OTloss = criterion(feature_val.detach(), feature_train.detach(), probability_train.squeeze())
        elif args.cost == 'label':
            OTloss = criterion_label(val_labels_onehot.float(),
                                     labels_onehot.float(),
                                     probability_train.squeeze())
        elif args.cost == 'combined':
            OTloss = criterion_fl(feature_val.detach(), feature_train.detach(),
                                  val_labels_onehot.float(),
                                  labels_onehot.float(),
                                  probability_train.squeeze())

        Attoptimizer.zero_grad()
        OTloss.backward()
        Attoptimizer.step()
        weightsbuffer[ids] = weights.data

        model.train()
        optimizer.zero_grad()
        _, logits = model(inputs)
        loss_train = criterion_classifier(logits, labels.long())

        loss = torch.sum(loss_train * weights.detach())
        loss.backward()
        optimizer.step()

        prec_train = accuracy(logits.data, labels, topk=(1,))[0]
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec_train.item(), inputs.size(0))
        if i % args.print_freq == 0 or i == len(train_loader) -1:
            print(f'Epoch: [{epoch}][{i}/{len(train_loader)}]\\t'
                  f'Loss {losses.val:.4f} ({losses.avg:.4f})\\t'
                  f'Prec@1 {top1.val:.3f} ({top1.avg:.3f})')


def validate(val_loader, model, device):
    top1 = AverageMeter()
    model.eval()
    with torch.no_grad():
      for i, batch in enumerate(val_loader):
          input, target, _ = tuple(t.to(device) for t in batch)
          target = target.squeeze().to(device)

          _, output = model(input)
          prec1 = accuracy(output.data, target, topk=(1,))[0]
          top1.update(prec1.item(), input.size(0))
    print(f' * Prec@1 {top1.avg:.3f}')
    return top1.avg, None, None


def build_model(load_pretrain, ckpt_path=None):
    model = ResNet32(args.num_classes)
    if load_pretrain:
        if not ckpt_path or not os.path.exists(ckpt_path):
            print(f"ERROR: Checkpoint file not found at {ckpt_path}")
            return None
        checkpoint = torch.load(ckpt_path)
        model.load_state_dict(checkpoint['state_dict'])
    if torch.cuda.is_available():
        model.cuda()
        torch.backends.cudnn.benchmark = True
    return model

def softmax_normalize(weights, temperature=1.):
    return F.softmax(weights / temperature, dim=0)

class AverageMeter(object):
    def __init__(self): self.reset()
    def reset(self): self.val = 0; self.avg = 0; self.sum = 0; self.count = 0
    def update(self, val, n=1):
        self.val = val; self.sum += val * n; self.count += n; self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

def save_checkpoint(args, state, is_best):
    path = 'checkpoint/ours/'
    save_name = args.save_name
    if not os.path.exists(path):
        os.makedirs(path)
    filename = path + save_name + '_ckpt.pth.tar'
    if is_best:
        torch.save(state, filename)

def to_categorical(labels):
    return F.one_hot(labels.long(), num_classes=args.num_classes)

if __name__ == '__main__':
    main()
print("File 'OT_train.py' has been saved.")

Overwriting OT_train.py


# **Imbalance Factor - 100**

In [9]:
# --- Run Full Stage 1 Pre-training (200 epochs) ---
# This will create the pre-trained model checkpoint needed for Stage 2.
# Using Imbalance Factor of 100 (0.01) as an example.

!python pretrain_stage1.py \
--dataset cifar10 \
--imb_factor 0.01 \
--epochs 200

File 'data_utils.py' has been saved.
File 'resnet.py' has been saved.
--> Creating Imbalanced Dataset...
Files already downloaded and verified
--> Building and Training Model for Stage 1...
Epoch 1/200: 100%|██████████████████| 97/97 [00:06<00:00, 15.18it/s, Loss=1.327]
Epoch 2/200: 100%|██████████████████| 97/97 [00:05<00:00, 18.94it/s, Loss=1.284]
Epoch 3/200: 100%|██████████████████| 97/97 [00:05<00:00, 18.74it/s, Loss=1.101]
Epoch 4/200: 100%|██████████████████| 97/97 [00:05<00:00, 18.54it/s, Loss=1.147]
Epoch 5/200: 100%|██████████████████| 97/97 [00:05<00:00, 18.71it/s, Loss=1.385]
Epoch 6/200: 100%|██████████████████| 97/97 [00:05<00:00, 18.94it/s, Loss=0.930]
Epoch 7/200: 100%|██████████████████| 97/97 [00:05<00:00, 19.08it/s, Loss=0.932]
Epoch 8/200: 100%|██████████████████| 97/97 [00:05<00:00, 19.22it/s, Loss=1.026]
Epoch 9/200: 100%|██████████████████| 97/97 [00:05<00:00, 19.30it/s, Loss=0.776]
Epoch 10/200: 100%|█████████████████| 97/97 [00:05<00:00, 19.23it/s, Loss=0.708]


In [10]:
# --- Run Full Stage 2 Re-weighting (40 epochs) ---
# This loads the checkpoint from Stage 1 and performs the final training.

!python OT_train.py \
--dataset cifar10 \
--imb_factor 0.01 \
--epochs 200 \
--cost combined \
--meta_set prototype \
--batch-size 16 \
--gpu 0 \
--save_name OT_cifar10_IF100_full_run \
--ckpt_path checkpoint/ours/pretrain/cifar10_imb0.01_stage1.pth

File 'data_utils.py' has been saved.
File 'resnet.py' has been saved.
File 'Sinkhorn_distance.py' has been saved.
File 'Sinkhorn_distance_fl.py' has been saved.
dataset=cifar10
cost=combined
meta_set=prototype
batch_size=16
num_classes=10
num_meta=10
imb_factor=0.01
epochs=200
lr=2e-05
momentum=0.9
nesterov=True
weight_decay=0.0005
no_cuda=False
seed=42
print_freq=100
gpu=0
save_name=OT_cifar10_IF100_full_run
idx=ours
ckpt_path=checkpoint/ours/pretrain/cifar10_imb0.01_stage1.pth
Files already downloaded and verified
Epoch: [160][0/774]\tLoss 0.0698 (0.0698)\tPrec@1 100.000 (100.000)
Epoch: [160][100/774]\tLoss 6.0032 (3.2475)\tPrec@1 87.500 (90.347)
Epoch: [160][200/774]\tLoss 12.0183 (3.8177)\tPrec@1 81.250 (89.956)
Epoch: [160][300/774]\tLoss 0.9147 (3.7208)\tPrec@1 87.500 (90.407)
Epoch: [160][400/774]\tLoss 1.5014 (3.6063)\tPrec@1 87.500 (90.461)
Epoch: [160][500/774]\tLoss 0.2453 (3.6400)\tPrec@1 93.750 (90.581)
Epoch: [160][600/774]\tLoss 9.7589 (3.5989)\tPrec@1 81.250 (90.381)
E

# **Imbalance Factor - 200**  

In [11]:
# --- Run Stage 1 Pre-training (IF=200) ---

!python pretrain_stage1.py \
--dataset cifar10 \
--imb_factor 0.005 \
--epochs 200

File 'data_utils.py' has been saved.
File 'resnet.py' has been saved.
--> Creating Imbalanced Dataset...
Files already downloaded and verified
--> Building and Training Model for Stage 1...
Epoch 1/200: 100%|██████████████████| 88/88 [00:05<00:00, 14.80it/s, Loss=1.390]
Epoch 2/200: 100%|██████████████████| 88/88 [00:04<00:00, 19.14it/s, Loss=1.242]
Epoch 3/200: 100%|██████████████████| 88/88 [00:04<00:00, 18.98it/s, Loss=1.102]
Epoch 4/200: 100%|██████████████████| 88/88 [00:04<00:00, 18.82it/s, Loss=1.009]
Epoch 5/200: 100%|██████████████████| 88/88 [00:04<00:00, 18.84it/s, Loss=0.747]
Epoch 6/200: 100%|██████████████████| 88/88 [00:04<00:00, 19.06it/s, Loss=0.741]
Epoch 7/200: 100%|██████████████████| 88/88 [00:04<00:00, 19.22it/s, Loss=1.092]
Epoch 8/200: 100%|██████████████████| 88/88 [00:04<00:00, 19.20it/s, Loss=0.777]
Epoch 9/200: 100%|██████████████████| 88/88 [00:04<00:00, 19.51it/s, Loss=0.960]
Epoch 10/200: 100%|█████████████████| 88/88 [00:04<00:00, 19.58it/s, Loss=0.813]


In [12]:
# --- Run Stage 2 Re-weighting (IF=200) ---

!python OT_train.py \
--dataset cifar10 \
--imb_factor 0.005 \
--epochs 200 \
--cost combined \
--meta_set prototype \
--batch-size 16 \
--save_name OT_cifar10_IF200_full_run \
--ckpt_path checkpoint/ours/pretrain/cifar10_imb0.005_stage1.pth

File 'data_utils.py' has been saved.
File 'resnet.py' has been saved.
File 'Sinkhorn_distance.py' has been saved.
File 'Sinkhorn_distance_fl.py' has been saved.
dataset=cifar10
cost=combined
meta_set=prototype
batch_size=16
num_classes=10
num_meta=10
imb_factor=0.005
epochs=200
lr=2e-05
momentum=0.9
nesterov=True
weight_decay=0.0005
no_cuda=False
seed=42
print_freq=100
gpu=0
save_name=OT_cifar10_IF200_full_run
idx=ours
ckpt_path=checkpoint/ours/pretrain/cifar10_imb0.005_stage1.pth
Files already downloaded and verified
Epoch: [160][0/699]\tLoss 0.0903 (0.0903)\tPrec@1 100.000 (100.000)
Epoch: [160][100/699]\tLoss 0.1325 (2.9888)\tPrec@1 100.000 (91.151)
Epoch: [160][200/699]\tLoss 18.7872 (2.8378)\tPrec@1 87.500 (90.796)
Epoch: [160][300/699]\tLoss 0.3522 (3.1290)\tPrec@1 75.000 (90.594)
Epoch: [160][400/699]\tLoss 1.5338 (2.9339)\tPrec@1 87.500 (90.726)
Epoch: [160][500/699]\tLoss 0.6737 (2.8850)\tPrec@1 87.500 (90.793)
Epoch: [160][600/699]\tLoss 1.1679 (2.9136)\tPrec@1 93.750 (90.734

# **Imbalance Factor - 50**

In [13]:
# --- Run Stage 1 Pre-training (IF=50) ---

!python pretrain_stage1.py \
--dataset cifar10 \
--imb_factor 0.02 \
--epochs 200

File 'data_utils.py' has been saved.
File 'resnet.py' has been saved.
--> Creating Imbalanced Dataset...
Files already downloaded and verified
--> Building and Training Model for Stage 1...
Epoch 1/200: 100%|████████████████| 110/110 [00:07<00:00, 15.51it/s, Loss=1.306]
Epoch 2/200: 100%|████████████████| 110/110 [00:05<00:00, 19.12it/s, Loss=1.170]
Epoch 3/200: 100%|████████████████| 110/110 [00:05<00:00, 18.93it/s, Loss=0.648]
Epoch 4/200: 100%|████████████████| 110/110 [00:05<00:00, 18.93it/s, Loss=1.285]
Epoch 5/200: 100%|████████████████| 110/110 [00:05<00:00, 19.22it/s, Loss=1.281]
Epoch 6/200: 100%|████████████████| 110/110 [00:05<00:00, 19.25it/s, Loss=0.880]
Epoch 7/200: 100%|████████████████| 110/110 [00:05<00:00, 19.65it/s, Loss=0.954]
Epoch 8/200: 100%|████████████████| 110/110 [00:05<00:00, 19.79it/s, Loss=0.657]
Epoch 9/200: 100%|████████████████| 110/110 [00:05<00:00, 19.80it/s, Loss=1.076]
Epoch 10/200: 100%|███████████████| 110/110 [00:05<00:00, 19.91it/s, Loss=0.937]


In [14]:
# --- Run Stage 2 Re-weighting (IF=50) ---

!python OT_train.py \
--dataset cifar10 \
--imb_factor 0.02 \
--epochs 200 \
--cost combined \
--meta_set prototype \
--batch-size 16 \
--save_name OT_cifar10_IF50_full_run \
--ckpt_path checkpoint/ours/pretrain/cifar10_imb0.02_stage1.pth

File 'data_utils.py' has been saved.
File 'resnet.py' has been saved.
File 'Sinkhorn_distance.py' has been saved.
File 'Sinkhorn_distance_fl.py' has been saved.
dataset=cifar10
cost=combined
meta_set=prototype
batch_size=16
num_classes=10
num_meta=10
imb_factor=0.02
epochs=200
lr=2e-05
momentum=0.9
nesterov=True
weight_decay=0.0005
no_cuda=False
seed=42
print_freq=100
gpu=0
save_name=OT_cifar10_IF50_full_run
idx=ours
ckpt_path=checkpoint/ours/pretrain/cifar10_imb0.02_stage1.pth
Files already downloaded and verified
Epoch: [160][0/873]\tLoss 0.0892 (0.0892)\tPrec@1 100.000 (100.000)
Epoch: [160][100/873]\tLoss 0.8891 (5.1745)\tPrec@1 100.000 (88.243)
Epoch: [160][200/873]\tLoss 0.3303 (4.5890)\tPrec@1 100.000 (88.464)
Epoch: [160][300/873]\tLoss 0.7260 (4.8903)\tPrec@1 93.750 (88.434)
Epoch: [160][400/873]\tLoss 2.4066 (4.9543)\tPrec@1 87.500 (88.903)
Epoch: [160][500/873]\tLoss 0.7500 (4.7674)\tPrec@1 81.250 (88.760)
Epoch: [160][600/873]\tLoss 4.4312 (4.7421)\tPrec@1 75.000 (88.769)
E

# **Imbalance Factor - 20**

In [15]:
# --- Run Stage 1 Pre-training (IF=20) ---

!python pretrain_stage1.py \
--dataset cifar10 \
--imb_factor 0.05 \
--epochs 200

File 'data_utils.py' has been saved.
File 'resnet.py' has been saved.
--> Creating Imbalanced Dataset...
Files already downloaded and verified
--> Building and Training Model for Stage 1...
Epoch 1/200: 100%|████████████████| 133/133 [00:08<00:00, 16.12it/s, Loss=1.849]
Epoch 2/200: 100%|████████████████| 133/133 [00:06<00:00, 19.13it/s, Loss=1.627]
Epoch 3/200: 100%|████████████████| 133/133 [00:07<00:00, 18.88it/s, Loss=1.328]
Epoch 4/200: 100%|████████████████| 133/133 [00:06<00:00, 19.08it/s, Loss=1.600]
Epoch 5/200: 100%|████████████████| 133/133 [00:06<00:00, 19.26it/s, Loss=1.151]
Epoch 6/200: 100%|████████████████| 133/133 [00:06<00:00, 19.38it/s, Loss=0.842]
Epoch 7/200: 100%|████████████████| 133/133 [00:06<00:00, 19.59it/s, Loss=0.927]
Epoch 8/200: 100%|████████████████| 133/133 [00:06<00:00, 19.80it/s, Loss=0.860]
Epoch 9/200: 100%|████████████████| 133/133 [00:06<00:00, 19.78it/s, Loss=0.967]
Epoch 10/200: 100%|███████████████| 133/133 [00:06<00:00, 19.79it/s, Loss=0.807]


In [16]:
# --- Run Stage 2 Re-weighting (IF=20) ---

!python OT_train.py \
--dataset cifar10 \
--imb_factor 0.05 \
--epochs 200 \
--cost combined \
--meta_set prototype \
--batch-size 16 \
--save_name OT_cifar10_IF20_full_run \
--ckpt_path checkpoint/ours/pretrain/cifar10_imb0.05_stage1.pth

File 'data_utils.py' has been saved.
File 'resnet.py' has been saved.
File 'Sinkhorn_distance.py' has been saved.
File 'Sinkhorn_distance_fl.py' has been saved.
dataset=cifar10
cost=combined
meta_set=prototype
batch_size=16
num_classes=10
num_meta=10
imb_factor=0.05
epochs=200
lr=2e-05
momentum=0.9
nesterov=True
weight_decay=0.0005
no_cuda=False
seed=42
print_freq=100
gpu=0
save_name=OT_cifar10_IF20_full_run
idx=ours
ckpt_path=checkpoint/ours/pretrain/cifar10_imb0.05_stage1.pth
Files already downloaded and verified
Epoch: [160][0/1062]\tLoss 15.8343 (15.8343)\tPrec@1 81.250 (81.250)
Epoch: [160][100/1062]\tLoss 6.3510 (5.0155)\tPrec@1 75.000 (89.418)
Epoch: [160][200/1062]\tLoss 1.1517 (5.1683)\tPrec@1 93.750 (89.614)
Epoch: [160][300/1062]\tLoss 13.7146 (5.2761)\tPrec@1 75.000 (89.514)
Epoch: [160][400/1062]\tLoss 4.4599 (5.4860)\tPrec@1 93.750 (89.308)
Epoch: [160][500/1062]\tLoss 25.8930 (5.2894)\tPrec@1 81.250 (89.371)
Epoch: [160][600/1062]\tLoss 9.4774 (5.2133)\tPrec@1 75.000 (89

# **Plot**

In [17]:
%%writefile replicate_figure_1.py
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import argparse
import os

# Import from your existing files
from data_utils import build_dataset, new_dataset
from resnet import ResNet32
from Sinkhorn_distance import SinkhornDistance
from Sinkhorn_distance_fl import SinkhornDistance as SinkhornDistance_fl

# Define the to_categorical helper function inside this script
def to_categorical(labels, num_classes):
    return F.one_hot(labels.long(), num_classes=num_classes)

def get_args():
    parser = argparse.ArgumentParser(description='Figure 1 Replication')
    # Make sure this checkpoint from your previous CIFAR-10 run exists
    parser.add_argument('--ckpt_path', type=str,
                        default='checkpoint/ours/OT_cifar10_IF100_full_run_ckpt.pth.tar',
                        help='Path to a pre-trained model checkpoint.')
    parser.add_argument('--num_classes', type=int, default=10)
    parser.add_argument('--num_meta', type=int, default=10)
    parser.add_argument('--gpu', default=0, type=int)
    return parser.parse_args(args=[]) # Use args=[] to prevent conflicts in Colab

def build_model(ckpt_path, num_classes):
    model = ResNet32(num_classes)
    if not os.path.exists(ckpt_path):
        print(f"ERROR: Checkpoint file not found at {ckpt_path}. Please ensure Stage 2 training was completed for CIFAR-10.")
        return None
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model

def main():
    args = get_args()
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")

    # 1. Load pre-trained model
    print("--> Loading pre-trained model...")
    model = build_model(args.ckpt_path, args.num_classes)
    if model is None: return

    # 2. Create the specific data splits as per the paper
    print("--> Preparing data...")
    meta_dataset, train_dataset, _ = build_dataset('cifar10', args.num_meta)

    # Create meta-set prototypes
    meta_loader = torch.utils.data.DataLoader(new_dataset(meta_dataset, train=False), batch_size=100, shuffle=False)
    meta_images, meta_labels, _ = next(iter(meta_loader))

    proto_images = torch.zeros([args.num_classes, 3, 32, 32])
    for i_cls in range(args.num_classes):
        proto_images[i_cls] = meta_images[meta_labels.squeeze() == i_cls].mean(dim=0)
    proto_images = proto_images.to(device)
    proto_labels_onehot = to_categorical(torch.arange(args.num_classes), args.num_classes).to(device)

    # Create the 55-sample imbalanced batch
    batch_images_list, batch_labels_list = [], []
    for i_cls in range(args.num_classes):
        num_samples = 10 - i_cls
        class_indices = [j for j, label in enumerate(train_dataset.targets) if label == i_cls]

        for idx in class_indices[:num_samples]:
            img, label, _ = new_dataset(train_dataset, train=False)[idx]
            batch_images_list.append(img)
            batch_labels_list.append(label.item())

    batch_images = torch.stack(batch_images_list).to(device)
    batch_labels = torch.tensor(batch_labels_list).to(device)
    batch_labels_onehot = to_categorical(batch_labels, args.num_classes).to(device)

    # 3. Get model features
    with torch.no_grad():
        proto_features, _ = model(proto_images)
        batch_features, _ = model(batch_images)

    # 4. Calculate cost matrices and learn weights
    cost_matrices = {}
    learned_weights = {}

    criterion_cos = SinkhornDistance(eps=0.1, max_iter=200, dis='cos').to(device)
    criterion_euc = SinkhornDistance(eps=0.1, max_iter=200, dis='euc').to(device)
    criterion_comb = SinkhornDistance_fl(eps=0.1, max_iter=200).to(device)

    print("--> Calculating matrices and learning weights...")
    for cost_type in ['feature', 'label', 'combined']:
        
        weights = torch.ones(55, requires_grad=True, device=device)
       
        
        Attoptimizer = torch.optim.SGD([weights], lr=0.1, momentum=0.9)

        for _ in range(20): # Small number of iterations to learn weights
            probability_train = F.softmax(weights, dim=0)
            if cost_type == 'feature':
                cost_matrix = 1 - F.cosine_similarity(batch_features.unsqueeze(1), proto_features.unsqueeze(0), dim=-1)
                OTloss = criterion_cos(proto_features.detach(), batch_features.detach(), probability_train.squeeze())
            elif cost_type == 'label':
                # Use torch.cdist for a clean euclidean distance matrix
                cost_matrix = torch.cdist(batch_labels_onehot.float(), proto_labels_onehot.float(), p=2.0)
                OTloss = criterion_euc(proto_labels_onehot.float(), batch_labels_onehot.float(), probability_train.squeeze())
            else: # combined
                cost_matrix_feat = 1 - F.cosine_similarity(batch_features.unsqueeze(1), proto_features.unsqueeze(0), dim=-1)
                cost_matrix_lab = torch.cdist(batch_labels_onehot.float(), proto_labels_onehot.float(), p=2.0)
                cost_matrix = 0.5 * cost_matrix_feat + 0.5 * cost_matrix_lab
                OTloss = criterion_comb(proto_features.detach(), batch_features.detach(),
                                        proto_labels_onehot.float(), batch_labels_onehot.float(), probability_train.squeeze())
            Attoptimizer.zero_grad()
            OTloss.backward()
            Attoptimizer.step()

        cost_matrices[cost_type] = cost_matrix.detach().cpu().numpy()
        learned_weights[cost_type] = F.softmax(weights.detach(), dim=0).cpu().numpy()

    # 5. Plotting
    print("--> Generating plot...")
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.tight_layout(pad=6.0)

    plot_order = ['feature', 'label', 'combined']
    titles = ['Feature-aware Cost Matrix', 'Label-aware Cost Matrix', 'Combined Cost Matrix']

    # Sort samples by class for clearer visualization
    sorted_indices = np.argsort(batch_labels.cpu().numpy())
    sorted_labels = batch_labels.cpu().numpy()[sorted_indices]

    for i, cost_type in enumerate(plot_order):
        # Sort the rows of the cost matrix to group by class
        sorted_cost_matrix = cost_matrices[cost_type][sorted_indices, :]
        
        sns.heatmap(sorted_cost_matrix, ax=axes[0, i], cmap='viridis', cbar=True)
        axes[0, i].set_title(titles[i], fontsize=14)
        axes[0, i].set_xlabel("Meta-Set Prototypes (Class 0-9)", fontsize=12)
        axes[0, i].set_ylabel("Imbalanced Batch Samples (Sorted by Class)", fontsize=12)

        # Sort the weights for plotting
        sorted_weights = learned_weights[cost_type][sorted_indices]
        
        scatter = axes[1, i].scatter(range(55), sorted_weights, c=sorted_labels, cmap='tab10', alpha=0.8)
        axes[1, i].set_title(f"Learned Weight Vector ({cost_type.capitalize()} Cost)", fontsize=14)
        axes[1, i].set_xlabel("Sample Index (Sorted by Class)", fontsize=12)
        axes[1, i].set_ylabel("Learned Weight", fontsize=12)
        
        # Create a legend
        handles, _ = scatter.legend_elements()
        legend_labels = [f'Class {c}' for c in range(args.num_classes)]
        axes[1, i].legend(handles, legend_labels, title="Classes")


    plt.savefig("replicated_figure_1.png")
    print("--> Plot saved as replicated_figure_1.png")


if __name__ == '__main__':
    main()

Overwriting replicate_figure_1.py
