In [2]:
from torchvision.models import resnet18
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from tqdm import tqdm
import torchvision.transforms as transforms
from torchvision import datasets, transforms
from tqdm import tqdm 
import time
from torch.utils.data import random_split, Dataset, DataLoader
from torchvision.models.resnet import ResNet18_Weights
import torchviz
import pickle
import random

seed = 42
torch.manual_seed(seed)

<torch._C.Generator at 0x1101292f0>

In [3]:
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

### Load Resent Model 

In [25]:
def get_resent18_model(num_classes=10):
    model = torchvision.models.resnet18(weights=None)  # Use 'weights' for pretrained models
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, num_classes)
    return model

def get_resent101_model(num_classes=10):
    model = torchvision.models.resnet101(weights=None)  # Use 'weights' for pretrained models
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, num_classes)
    return model

In [26]:
model = get_resent101_model()
print(type(model))

<class 'torchvision.models.resnet.ResNet'>


In [8]:
# Freeze pre-trained layers
for param in model.parameters():
    param.requires_grad = False

# Unfreeze some layers for fine-tuning
for param in model.layer4.parameters():
    param.requires_grad = True

In [6]:
# Define data transforms
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

In [7]:
# Load CIFAR10 datasets
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test)

# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [10]:
print(train_dataset[0][0].shape)

torch.Size([3, 32, 32])


## LeNet Model Definition

In [16]:
class LeNet(nn.Module):    
    def __init__(self, out_classes=10):
        super(LeNet, self).__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )

        self.classifier = nn.Sequential(
            nn.Linear(400,120),  #in_features = 16 x5x5 
            nn.ReLU(),
            nn.Linear(120,84),
            nn.ReLU(),
            nn.Linear(84, out_classes),
            # nn.Softmax(dim=1)
        )
        
    def forward(self,x): 
        a1=self.feature_extractor(x)
        # print(a1.shape)
        a1 = torch.flatten(a1,1)
        a2=self.classifier(a1)
        return a2

# Baseline Model Training

### Basic Train Loop

In [18]:
seed = 42
torch.manual_seed(seed)

# Define the Model
# model = get_cifar10_model()
model = LeNet()
model = model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

# Train the model
epochs = 20
model.train()
start_time = time.time()
for epoch in tqdm(range(epochs)):
    # Train loop
    for images, labels in tqdm(train_dataloader):

        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        
        # Backward pass and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Evaluate on test set
    # correct = 0
    # total = 0
    # with torch.no_grad():
    #     for images, labels in tqdm(test_dataloader):
    #         images = images.to(device)
    #         labels = labels.to(device)

    #         outputs = model(images)
    #         predictions = torch.argmax(outputs, dim=1)
    #         correct += (predictions == labels).sum().item()
    #         total += labels.size(0)

    # accuracy = correct / total
    # print(f"Epoch: [{epoch+1}/{epochs}], Accuracy: {accuracy:.4f}")


    
print("--- %s seconds ---" % (time.time() - start_time))

# Evaluate on test set
correct = 0
total = 0
with torch.no_grad():
    for images, labels in tqdm(test_dataloader):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        predictions = torch.argmax(outputs, dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

accuracy = correct / total
print(f"Accuracy: {accuracy:.4f}")
# print(f"Epoch: [{epoch+1}/{epochs}], Accuracy: {accuracy:.4f}")


100%|██████████| 782/782 [00:18<00:00, 42.66it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.41it/s] 
  5%|▌         | 1/20 [00:30<09:48, 30.99s/it]

Epoch: [1/20], Accuracy: 0.4452


100%|██████████| 782/782 [00:18<00:00, 42.33it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.47it/s] 
 10%|█         | 2/20 [01:02<09:18, 31.04s/it]

Epoch: [2/20], Accuracy: 0.4982


100%|██████████| 782/782 [00:18<00:00, 42.77it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.31it/s] 
 15%|█▌        | 3/20 [01:33<08:47, 31.04s/it]

Epoch: [3/20], Accuracy: 0.5362


100%|██████████| 782/782 [00:18<00:00, 42.22it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.40it/s] 
 20%|██        | 4/20 [02:04<08:17, 31.10s/it]

Epoch: [4/20], Accuracy: 0.5314


100%|██████████| 782/782 [00:18<00:00, 42.93it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.46it/s] 
 25%|██▌       | 5/20 [02:35<07:45, 31.00s/it]

Epoch: [5/20], Accuracy: 0.5606


100%|██████████| 782/782 [00:18<00:00, 42.88it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.41it/s] 
 30%|███       | 6/20 [03:06<07:13, 30.97s/it]

Epoch: [6/20], Accuracy: 0.5699


100%|██████████| 782/782 [00:18<00:00, 42.97it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.43it/s] 
 35%|███▌      | 7/20 [03:36<06:42, 30.93s/it]

Epoch: [7/20], Accuracy: 0.5707


100%|██████████| 782/782 [00:18<00:00, 42.86it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.48it/s]
 40%|████      | 8/20 [04:07<06:10, 30.90s/it]

Epoch: [8/20], Accuracy: 0.5827


100%|██████████| 782/782 [00:18<00:00, 42.91it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.46it/s]
 45%|████▌     | 9/20 [04:38<05:39, 30.88s/it]

Epoch: [9/20], Accuracy: 0.5819


100%|██████████| 782/782 [00:18<00:00, 42.70it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.46it/s]
 50%|█████     | 10/20 [05:09<05:08, 30.89s/it]

Epoch: [10/20], Accuracy: 0.5955


100%|██████████| 782/782 [00:18<00:00, 42.70it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.47it/s] 
 55%|█████▌    | 11/20 [05:40<04:38, 30.90s/it]

Epoch: [11/20], Accuracy: 0.6033


100%|██████████| 782/782 [00:18<00:00, 42.85it/s] 
100%|██████████| 157/157 [00:59<00:00,  2.65it/s] 
 60%|██████    | 12/20 [06:57<06:00, 45.05s/it]

Epoch: [12/20], Accuracy: 0.5998


100%|██████████| 782/782 [00:18<00:00, 41.85it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.19it/s] 
 65%|██████▌   | 13/20 [07:29<04:46, 40.96s/it]

Epoch: [13/20], Accuracy: 0.6070


100%|██████████| 782/782 [00:18<00:00, 41.98it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.34it/s] 
 70%|███████   | 14/20 [08:00<03:48, 38.06s/it]

Epoch: [14/20], Accuracy: 0.6171


100%|██████████| 782/782 [00:18<00:00, 42.54it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.33it/s] 
 75%|███████▌  | 15/20 [08:31<02:59, 35.97s/it]

Epoch: [15/20], Accuracy: 0.6298


100%|██████████| 782/782 [00:19<00:00, 40.43it/s]
100%|██████████| 157/157 [00:12<00:00, 12.48it/s]
 80%|████████  | 16/20 [09:03<02:19, 34.75s/it]

Epoch: [16/20], Accuracy: 0.6320


100%|██████████| 782/782 [00:18<00:00, 42.37it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.44it/s]
 85%|████████▌ | 17/20 [09:34<01:40, 33.65s/it]

Epoch: [17/20], Accuracy: 0.6250


100%|██████████| 782/782 [00:18<00:00, 42.83it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.45it/s] 
 90%|█████████ | 18/20 [10:05<01:05, 32.82s/it]

Epoch: [18/20], Accuracy: 0.6331


100%|██████████| 782/782 [00:18<00:00, 42.74it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.46it/s] 
 95%|█████████▌| 19/20 [10:36<00:32, 32.25s/it]

Epoch: [19/20], Accuracy: 0.6495


100%|██████████| 782/782 [00:18<00:00, 42.74it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.43it/s] 
100%|██████████| 20/20 [11:07<00:00, 33.38s/it]


Epoch: [20/20], Accuracy: 0.6347
--- 667.5797848701477 seconds ---


100%|██████████| 157/157 [00:12<00:00, 12.44it/s] 

Accuracy: 0.6347





# Reweight Model Training

In [9]:
# Load CIFAR10 datasets
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test)

split_ratio = 0.95 # 80% for test, 20% for validation
n_samples = len(test_dataset)
n_test = int(n_samples * split_ratio)
n_val = n_samples - n_test
test_dataset, val_dataset = random_split(test_dataset, [n_test, n_val])


# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


### Random Sampler for Sampling Validation Data

In [10]:
class RandomSubsetSampler(torch.utils.data.Sampler):
    def __init__(self, dataset, subset_size):
        self.dataset = dataset
        self.subset_size = subset_size

    def __iter__(self):
        indices = random.sample(range(len(self.dataset)), self.subset_size)
        return iter(indices)

    def __len__(self):
        return self.subset_size

subset_sampler = RandomSubsetSampler(val_dataset, 64)
subset_dataloader = DataLoader(val_dataset, sampler=subset_sampler)

### Meta Baseline Trainloop

In [18]:
seed = 42
torch.manual_seed(seed)

# Define the Model
# model = get_cifar10_model()
model = LeNet()
model = model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()
loss_fn_meta = nn.CrossEntropyLoss(reduction='none')

# Train the model
epochs = 20
lr = 0.001
model.train()
start_time = time.time()
for epoch in tqdm(range(epochs)):
    # Train loop
    for images, labels in tqdm(train_dataloader):
        
        images = images.to(device)
        labels = labels.to(device)

        # meta_net = get_cifar10_model()
        meta_net = LeNet()
        meta_net.load_state_dict(model.state_dict())

        meta_net = meta_net.to(device)

        optimizer_meta = torch.optim.Adam(meta_net.parameters())

        meta_net.train()
        
        y_f_hat = meta_net(images)
        cost = loss_fn_meta(y_f_hat, labels)
        eps = torch.zeros(cost.size(), requires_grad=True).to(device)
        l_f_meta = torch.sum(cost*eps)

        # meta_net.zero_grad()
        optimizer_meta.zero_grad()
        eps.retain_grad()
        l_f_meta.backward()
        optimizer_meta.step()

        meta_net.eval()

        # grads = torch.autograd.grad(l_f_meta, (meta_net.parameters()), create_graph=True)
        # meta_net.update_params(lr, source_params=grads)
        
        val_images, val_labels = next(iter(subset_dataloader))
        # val_images, val_labels = next(iter(val_dataloader))
        val_images = val_images.to(device)
        val_labels = val_labels.to(device)

        y_g_hat = meta_net(val_images)
        l_g_meta = loss_fn(y_g_hat, val_labels)

        # grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True)[0]
        # grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True, allow_unused=True)[0]
        # print(grad_eps)

        with torch.no_grad():
            optimizer_meta.zero_grad()
            l_g_meta.backward()
            grad_eps = eps.grad
        
        # print(grad_eps)
        w_tilde = torch.clamp(grad_eps,min=0)
        # w_tilde = torch.clamp(-grad_eps,min=0)
        norm_c = torch.sum(w_tilde)

        if norm_c != 0:
            w = w_tilde / norm_c
        else:
            w = w_tilde
        
        # print(w)
        # break
        # Forward Pass
        outputs = model(images)
        loss = loss_fn_meta(outputs, labels)
        loss = torch.sum(loss*w)
        
        # Backward pass and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
print("--- %s seconds ---" % (time.time() - start_time))

# Evaluate on test set
correct = 0
total = 0
with torch.no_grad():
    for images, labels in tqdm(test_dataloader):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        predictions = torch.argmax(outputs, dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

accuracy = correct / total
print(f"Accuracy: {accuracy:.4f}")



100%|██████████| 782/782 [00:32<00:00, 23.77it/s]
100%|██████████| 782/782 [00:33<00:00, 23.69it/s]
100%|██████████| 782/782 [00:32<00:00, 23.71it/s]
100%|██████████| 782/782 [00:32<00:00, 23.75it/s]
100%|██████████| 782/782 [00:33<00:00, 23.12it/s]
100%|██████████| 782/782 [00:32<00:00, 24.13it/s]
100%|██████████| 782/782 [00:33<00:00, 23.55it/s]
100%|██████████| 782/782 [00:32<00:00, 23.91it/s]
100%|██████████| 782/782 [00:33<00:00, 23.65it/s]
100%|██████████| 782/782 [00:34<00:00, 22.95it/s]
100%|██████████| 782/782 [00:34<00:00, 22.58it/s]
100%|██████████| 782/782 [00:33<00:00, 23.10it/s]
100%|██████████| 782/782 [00:32<00:00, 24.30it/s]
100%|██████████| 782/782 [00:32<00:00, 23.98it/s]
100%|██████████| 782/782 [00:34<00:00, 22.82it/s]
100%|██████████| 782/782 [00:33<00:00, 23.13it/s]
100%|██████████| 782/782 [00:35<00:00, 22.14it/s]
100%|██████████| 782/782 [00:33<00:00, 23.09it/s]
100%|██████████| 782/782 [00:34<00:00, 22.68it/s]
100%|██████████| 782/782 [00:34<00:00, 22.50it/s]


--- 670.9591670036316 seconds ---


100%|██████████| 149/149 [00:12<00:00, 11.77it/s]

Accuracy: 0.5983





# Milo Setup

### Load Data

In [21]:
num_classes = 10
class_data = []
subset_fraction = 0.3
for i in range(num_classes):
    with open(f"milo-base/class-data-{subset_fraction}/class_{i}.pkl", "rb") as f:
        S = pickle.load(f)
        class_data.append(S)

In [22]:
num_sets = len(class_data[0])
data = []
for i in range(num_sets):
    S = []
    for j in range(num_classes):
        S.extend(class_data[j][i])
    data.append(S)

In [26]:
print(len(data[0]))

15000


### Define Dataloader

In [27]:
class SubDataset(Dataset):
    def __init__(self, indices, dataset):
        self.indices = indices
        self.dataset = dataset

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        data_point = self.dataset[index]
        return data_point

## Milo Training Loop

In [28]:
# Define Model
model = LeNet()
model = model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

# Train the model
epochs = 20
model.train()
R = 1
start_time = time.time()
for epoch in tqdm(range(epochs)):
    # Train loop

    if epoch%R==0:
        sub_dataset = SubDataset(indices=data[epoch//R], dataset=train_dataset)
        subset_train_dataloader = DataLoader(sub_dataset, batch_size=64, shuffle=True)
        
    for images, labels in tqdm(subset_train_dataloader):

        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        
        # Backward pass and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Evaluate on test set
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(test_dataloader):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    print(f"Epoch: [{epoch+1}/{epochs}], Accuracy: {accuracy:.4f}")


print("--- %s seconds ---" % (time.time() - start_time))

# Evaluate on test set
correct = 0
total = 0
with torch.no_grad():
    for images, labels in tqdm(test_dataloader):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        predictions = torch.argmax(outputs, dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

accuracy = correct / total
print(f"Accuracy: {accuracy:.4f}")


  0%|          | 0/20 [00:00<?, ?it/s]

100%|██████████| 235/235 [00:03<00:00, 72.58it/s]
100%|██████████| 157/157 [00:13<00:00, 12.03it/s] 
  5%|▌         | 1/20 [00:16<05:09, 16.29s/it]

Epoch: [1/20], Accuracy: 0.3684


100%|██████████| 235/235 [00:02<00:00, 79.35it/s]
100%|██████████| 157/157 [00:12<00:00, 12.47it/s] 
 10%|█         | 2/20 [00:31<04:45, 15.86s/it]

Epoch: [2/20], Accuracy: 0.4113


100%|██████████| 235/235 [00:03<00:00, 78.19it/s]
100%|██████████| 157/157 [00:12<00:00, 12.45it/s] 
 15%|█▌        | 3/20 [00:47<04:27, 15.76s/it]

Epoch: [3/20], Accuracy: 0.4295


100%|██████████| 235/235 [00:02<00:00, 79.78it/s]
100%|██████████| 157/157 [00:12<00:00, 12.48it/s] 
 20%|██        | 4/20 [01:03<04:10, 15.67s/it]

Epoch: [4/20], Accuracy: 0.4765


100%|██████████| 235/235 [00:02<00:00, 80.32it/s]
100%|██████████| 157/157 [00:12<00:00, 12.45it/s] 
 25%|██▌       | 5/20 [01:18<03:54, 15.62s/it]

Epoch: [5/20], Accuracy: 0.4606


100%|██████████| 235/235 [00:02<00:00, 80.12it/s]
100%|██████████| 157/157 [00:12<00:00, 12.46it/s] 
 30%|███       | 6/20 [01:34<03:38, 15.59s/it]

Epoch: [6/20], Accuracy: 0.4852


100%|██████████| 235/235 [00:02<00:00, 80.17it/s]
100%|██████████| 157/157 [00:12<00:00, 12.45it/s] 
 35%|███▌      | 7/20 [01:49<03:22, 15.58s/it]

Epoch: [7/20], Accuracy: 0.4995


100%|██████████| 235/235 [00:03<00:00, 78.04it/s]
100%|██████████| 157/157 [00:12<00:00, 12.40it/s] 
 40%|████      | 8/20 [02:05<03:07, 15.61s/it]

Epoch: [8/20], Accuracy: 0.5013


100%|██████████| 235/235 [00:03<00:00, 77.66it/s]
100%|██████████| 157/157 [00:12<00:00, 12.41it/s] 
 45%|████▌     | 9/20 [02:20<02:51, 15.63s/it]

Epoch: [9/20], Accuracy: 0.5195


100%|██████████| 235/235 [00:02<00:00, 78.64it/s]
100%|██████████| 157/157 [00:12<00:00, 12.41it/s] 
 50%|█████     | 10/20 [02:36<02:36, 15.64s/it]

Epoch: [10/20], Accuracy: 0.5195


100%|██████████| 235/235 [00:02<00:00, 78.73it/s]
100%|██████████| 157/157 [00:12<00:00, 12.31it/s] 
 55%|█████▌    | 11/20 [02:52<02:21, 15.67s/it]

Epoch: [11/20], Accuracy: 0.5297


100%|██████████| 235/235 [00:03<00:00, 77.28it/s]
100%|██████████| 157/157 [00:12<00:00, 12.41it/s]
 60%|██████    | 12/20 [03:08<02:05, 15.68s/it]

Epoch: [12/20], Accuracy: 0.5343


100%|██████████| 235/235 [00:03<00:00, 77.44it/s]
100%|██████████| 157/157 [00:12<00:00, 12.21it/s] 
 65%|██████▌   | 13/20 [03:23<01:50, 15.75s/it]

Epoch: [13/20], Accuracy: 0.5236


100%|██████████| 235/235 [00:02<00:00, 79.26it/s]
100%|██████████| 157/157 [00:12<00:00, 12.43it/s] 
 70%|███████   | 14/20 [03:39<01:34, 15.70s/it]

Epoch: [14/20], Accuracy: 0.5403


100%|██████████| 235/235 [00:02<00:00, 79.52it/s]
100%|██████████| 157/157 [00:12<00:00, 12.30it/s] 
 75%|███████▌  | 15/20 [03:55<01:18, 15.71s/it]

Epoch: [15/20], Accuracy: 0.5397


100%|██████████| 235/235 [00:02<00:00, 79.95it/s]
100%|██████████| 157/157 [00:12<00:00, 12.43it/s] 
 80%|████████  | 16/20 [04:10<01:02, 15.67s/it]

Epoch: [16/20], Accuracy: 0.5465


100%|██████████| 235/235 [00:03<00:00, 76.80it/s]
100%|██████████| 157/157 [00:12<00:00, 12.33it/s] 
 85%|████████▌ | 17/20 [04:26<00:47, 15.71s/it]

Epoch: [17/20], Accuracy: 0.5475


100%|██████████| 235/235 [00:02<00:00, 81.31it/s]
100%|██████████| 157/157 [00:12<00:00, 12.16it/s] 
 90%|█████████ | 18/20 [04:42<00:31, 15.74s/it]

Epoch: [18/20], Accuracy: 0.5437


100%|██████████| 235/235 [00:02<00:00, 80.23it/s]
100%|██████████| 157/157 [00:12<00:00, 12.47it/s] 
 95%|█████████▌| 19/20 [04:58<00:15, 15.68s/it]

Epoch: [19/20], Accuracy: 0.5591


100%|██████████| 235/235 [00:02<00:00, 79.44it/s]
100%|██████████| 157/157 [00:12<00:00, 12.44it/s] 
100%|██████████| 20/20 [05:13<00:00, 15.68s/it]


Epoch: [20/20], Accuracy: 0.5570
--- 313.63228011131287 seconds ---


100%|██████████| 157/157 [00:12<00:00, 12.42it/s] 

Accuracy: 0.5570





In [27]:
# Load CIFAR10 datasets
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test)

split_ratio = 0.90
n_samples = len(test_dataset)
n_test = int(n_samples * split_ratio)
n_val = n_samples - n_test
test_dataset, val_dataset = random_split(test_dataset, [n_test, n_val])


# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [29]:
subset_sampler = RandomSubsetSampler(val_dataset, 64)
subset_dataloader = DataLoader(val_dataset, sampler=subset_sampler)

### Meta-Milo Training loop

In [30]:
# Define Model
model = LeNet()
model = model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()
loss_fn_meta = nn.CrossEntropyLoss(reduction='none')

# Train the model
epochs = 20
# lr = 0.001
model.train()
start_time = time.time()
for epoch in tqdm(range(epochs)):
    # Train loop

    if epoch%R==0:
        sub_dataset = SubDataset(indices=data[epoch//R], dataset=train_dataset)
        # train_dataloader = DataLoader(sub_dataset, batch_size=64, shuffle=True, num_workers=2)
        train_dataloader = DataLoader(sub_dataset, batch_size=64, shuffle=True)
    
    for images, labels in tqdm(train_dataloader):
        
        images = images.to(device)
        labels = labels.to(device)

        # meta_net = get_cifar10_model()
        meta_net = LeNet()
        meta_net.load_state_dict(model.state_dict())

        meta_net = meta_net.to(device)

        optimizer_meta = torch.optim.Adam(meta_net.parameters())

        meta_net.train()
        
        y_f_hat = meta_net(images)
        cost = loss_fn_meta(y_f_hat, labels)
        eps = torch.zeros(cost.size(), requires_grad=True).to(device)
        l_f_meta = torch.sum(cost*eps)

        # meta_net.zero_grad()
        optimizer_meta.zero_grad()
        eps.retain_grad()
        l_f_meta.backward()
        optimizer_meta.step()

        meta_net.eval()

        # grads = torch.autograd.grad(l_f_meta, (meta_net.parameters()), create_graph=True)
        # meta_net.update_params(lr, source_params=grads)
        
        val_images, val_labels = next(iter(subset_dataloader))
        # val_images, val_labels = next(iter(val_dataloader))
        val_images = val_images.to(device)
        val_labels = val_labels.to(device)

        y_g_hat = meta_net(val_images)
        l_g_meta = loss_fn(y_g_hat, val_labels)

        # grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True)[0]
        # grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True, allow_unused=True)[0]
        # print(grad_eps)

        with torch.no_grad():
            optimizer_meta.zero_grad()
            l_g_meta.backward()
            grad_eps = eps.grad
        
        # print(grad_eps)
        w_tilde = torch.clamp(grad_eps,min=0)
        # w_tilde = torch.clamp(-grad_eps,min=0)
        norm_c = torch.sum(w_tilde)

        if norm_c != 0:
            w = w_tilde / norm_c
        else:
            w = w_tilde
        
        # print(w)
        # break
        # Forward Pass
        outputs = model(images)
        loss = loss_fn_meta(outputs, labels)
        loss = torch.sum(loss*w)
        
        # Backward pass and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

print("--- %s seconds ---" % (time.time() - start_time))

# Evaluate on test set
correct = 0
total = 0
with torch.no_grad():
    for images, labels in tqdm(test_dataloader):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        predictions = torch.argmax(outputs, dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

accuracy = correct / total
print(f"Accuracy: {accuracy:.4f}")

  0%|          | 0/20 [00:00<?, ?it/s]

100%|██████████| 235/235 [00:07<00:00, 32.03it/s]
100%|██████████| 235/235 [00:06<00:00, 33.70it/s]
100%|██████████| 235/235 [00:06<00:00, 33.63it/s]
100%|██████████| 235/235 [00:07<00:00, 33.31it/s]
100%|██████████| 235/235 [00:06<00:00, 33.75it/s]
100%|██████████| 235/235 [00:06<00:00, 33.81it/s]
100%|██████████| 235/235 [00:06<00:00, 33.72it/s]
100%|██████████| 235/235 [00:06<00:00, 33.70it/s]
100%|██████████| 235/235 [00:06<00:00, 33.84it/s]
100%|██████████| 235/235 [00:06<00:00, 33.78it/s]
100%|██████████| 235/235 [00:07<00:00, 33.13it/s]
100%|██████████| 235/235 [00:07<00:00, 33.40it/s]
100%|██████████| 235/235 [00:06<00:00, 33.76it/s]
100%|██████████| 235/235 [00:07<00:00, 33.48it/s]
100%|██████████| 235/235 [00:06<00:00, 33.64it/s]
100%|██████████| 235/235 [00:07<00:00, 32.41it/s]
100%|██████████| 235/235 [00:07<00:00, 32.10it/s]
100%|██████████| 235/235 [00:07<00:00, 31.29it/s]
100%|██████████| 235/235 [00:07<00:00, 30.89it/s]
100%|██████████| 235/235 [00:07<00:00, 31.16it/s]


--- 142.4714057445526 seconds ---


100%|██████████| 141/141 [00:00<00:00, 141.15it/s]

Accuracy: 0.5166





# Random Subset Meta 