In [1]:
from torchvision.models import resnet18
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from tqdm import tqdm
import torchvision.transforms as transforms
from torchvision import datasets, transforms
from tqdm import tqdm 
import time
from torch.utils.data import random_split, Dataset, DataLoader
from torchvision.models.resnet import ResNet18_Weights
import torchviz
import pickle
import random

seed = 42
torch.manual_seed(seed)

In [13]:
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

### Load Resent Model 

In [3]:
def get_cifar10_model():
    model = torchvision.models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)  # Use 'weights' for pretrained models
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, 10)
    return model

In [4]:
model = get_cifar10_model()
print(type(model))

<class 'torchvision.models.resnet.ResNet'>


In [8]:
# Freeze pre-trained layers
for param in model.parameters():
    param.requires_grad = False

# Unfreeze some layers for fine-tuning
for param in model.layer4.parameters():
    param.requires_grad = True

In [8]:
# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

In [3]:
# Define data transforms
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

In [7]:
# Load CIFAR10 datasets
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test)

# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


## LeNet Model Definition

In [4]:
class LeNet(nn.Module):    
    def __init__(self):
        super(LeNet, self).__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )

        self.classifier = nn.Sequential(
            nn.Linear(400,120),  #in_features = 16 x5x5 
            nn.ReLU(),
            nn.Linear(120,84),
            nn.ReLU(),
            nn.Linear(84,10),
            nn.Softmax(dim=1)

        )
        
    def forward(self,x): 
        a1=self.feature_extractor(x)
        # print(a1.shape)
        a1 = torch.flatten(a1,1)
        a2=self.classifier(a1)
        return a2

# Baseline Model Training

### Basic Train Loop

In [9]:
# Define the Model
# model = get_cifar10_model()
model = LeNet()
model = model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

# Train the model
epochs = 20
model.train()
start_time = time.time()
for epoch in tqdm(range(epochs)):
    # Train loop
    for images, labels in tqdm(train_dataloader):

        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        
        # Backward pass and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate on test set
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(test_dataloader):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    print(f"Epoch: [{epoch+1}/{epochs}], Accuracy: {accuracy:.4f}")

print("--- %s seconds ---" % (time.time() - start_time))

100%|██████████| 782/782 [00:24<00:00, 32.11it/s] 
100%|██████████| 157/157 [00:13<00:00, 11.22it/s] 
  5%|▌         | 1/20 [00:38<12:08, 38.35s/it]

Epoch: [1/20], Accuracy: 0.3563


100%|██████████| 782/782 [00:20<00:00, 37.63it/s] 
100%|██████████| 157/157 [00:13<00:00, 11.45it/s] 
 10%|█         | 2/20 [01:12<10:49, 36.09s/it]

Epoch: [2/20], Accuracy: 0.4079


100%|██████████| 782/782 [00:20<00:00, 38.52it/s]
100%|██████████| 157/157 [00:13<00:00, 11.52it/s] 
 15%|█▌        | 3/20 [01:46<09:56, 35.10s/it]

Epoch: [3/20], Accuracy: 0.4171


100%|██████████| 782/782 [00:20<00:00, 38.89it/s] 
100%|██████████| 157/157 [00:13<00:00, 11.66it/s] 
 20%|██        | 4/20 [02:20<09:12, 34.50s/it]

Epoch: [4/20], Accuracy: 0.4544


100%|██████████| 782/782 [00:21<00:00, 35.74it/s] 
100%|██████████| 157/157 [00:13<00:00, 11.31it/s] 
 25%|██▌       | 5/20 [02:56<08:44, 34.96s/it]

Epoch: [5/20], Accuracy: 0.4802


100%|██████████| 782/782 [00:20<00:00, 38.42it/s] 
100%|██████████| 157/157 [00:13<00:00, 11.23it/s] 
 30%|███       | 6/20 [03:30<08:06, 34.75s/it]

Epoch: [6/20], Accuracy: 0.4861


100%|██████████| 782/782 [00:18<00:00, 41.48it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.33it/s] 
 35%|███▌      | 7/20 [04:02<07:18, 33.72s/it]

Epoch: [7/20], Accuracy: 0.4993


100%|██████████| 782/782 [00:17<00:00, 43.64it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.49it/s]
 40%|████      | 8/20 [04:32<06:32, 32.70s/it]

Epoch: [8/20], Accuracy: 0.5031


100%|██████████| 782/782 [00:18<00:00, 43.18it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.20it/s]
 45%|████▌     | 9/20 [05:03<05:53, 32.17s/it]

Epoch: [9/20], Accuracy: 0.5115


100%|██████████| 782/782 [00:17<00:00, 43.72it/s] 
100%|██████████| 157/157 [00:13<00:00, 12.00it/s] 
 50%|█████     | 10/20 [05:34<05:18, 31.81s/it]

Epoch: [10/20], Accuracy: 0.5257


100%|██████████| 782/782 [00:17<00:00, 44.35it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.40it/s] 
 55%|█████▌    | 11/20 [06:04<04:42, 31.36s/it]

Epoch: [11/20], Accuracy: 0.5230


100%|██████████| 782/782 [00:17<00:00, 43.74it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.21it/s]
 60%|██████    | 12/20 [06:35<04:09, 31.17s/it]

Epoch: [12/20], Accuracy: 0.5164


100%|██████████| 782/782 [00:17<00:00, 43.81it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.45it/s]
 65%|██████▌   | 13/20 [07:06<03:36, 30.96s/it]

Epoch: [13/20], Accuracy: 0.5334


100%|██████████| 782/782 [00:17<00:00, 44.39it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.41it/s]
 70%|███████   | 14/20 [07:36<03:04, 30.76s/it]

Epoch: [14/20], Accuracy: 0.5277


100%|██████████| 782/782 [00:17<00:00, 43.86it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.34it/s]
 75%|███████▌  | 15/20 [08:07<02:33, 30.70s/it]

Epoch: [15/20], Accuracy: 0.5364


100%|██████████| 782/782 [00:18<00:00, 43.37it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.23it/s] 
 80%|████████  | 16/20 [08:37<02:03, 30.76s/it]

Epoch: [16/20], Accuracy: 0.5463


100%|██████████| 782/782 [00:17<00:00, 44.64it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.42it/s]
 85%|████████▌ | 17/20 [09:08<01:31, 30.58s/it]

Epoch: [17/20], Accuracy: 0.5497


100%|██████████| 782/782 [00:17<00:00, 44.33it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.53it/s]
 90%|█████████ | 18/20 [09:38<01:00, 30.46s/it]

Epoch: [18/20], Accuracy: 0.5386


100%|██████████| 782/782 [00:17<00:00, 44.56it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.45it/s]
 95%|█████████▌| 19/20 [10:08<00:30, 30.38s/it]

Epoch: [19/20], Accuracy: 0.5402


100%|██████████| 782/782 [00:18<00:00, 43.27it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.51it/s]
100%|██████████| 20/20 [10:39<00:00, 31.95s/it]

Epoch: [20/20], Accuracy: 0.5476
--- 639.0781149864197 seconds ---





# Reweight Model Training

In [6]:
# Load CIFAR10 datasets
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test)

split_ratio = 0.95 # 80% for test, 20% for validation
n_samples = len(test_dataset)
n_test = int(n_samples * split_ratio)
n_val = n_samples - n_test
test_dataset, val_dataset = random_split(test_dataset, [n_test, n_val])


# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


### Random Sampler for Sampling Validation Data

In [23]:
class RandomSubsetSampler(torch.utils.data.Sampler):
    def __init__(self, dataset, subset_size):
        self.dataset = dataset
        self.subset_size = subset_size

    def __iter__(self):
        indices = random.sample(range(len(self.dataset)), self.subset_size)
        return iter(indices)

    def __len__(self):
        return self.subset_size

subset_sampler = RandomSubsetSampler(val_dataset, 64)
subset_dataloader = DataLoader(val_dataset, sampler=subset_sampler)

NameError: name 'val_dataset' is not defined

### Meta Baseline Trainloop

In [39]:
# Define the Model
# model = get_cifar10_model()
model = LeNet()
model = model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()
loss_fn_meta = nn.CrossEntropyLoss(reduction='none')

# Train the model
epochs = 20
lr = 0.001
model.train()
start_time = time.time()
for epoch in tqdm(range(epochs)):
    # Train loop
    for images, labels in tqdm(train_dataloader):
        
        images = images.to(device)
        labels = labels.to(device)

        # meta_net = get_cifar10_model()
        meta_net = LeNet()
        meta_net.load_state_dict(model.state_dict())

        meta_net = meta_net.to(device)

        optimizer_meta = torch.optim.Adam(meta_net.parameters())

        meta_net.train()
        
        y_f_hat = meta_net(images)
        cost = loss_fn_meta(y_f_hat, labels)
        eps = torch.zeros(cost.size(), requires_grad=True).to(device)
        l_f_meta = torch.sum(cost*eps)

        # meta_net.zero_grad()
        optimizer_meta.zero_grad()
        eps.retain_grad()
        l_f_meta.backward()
        optimizer_meta.step()

        meta_net.eval()

        # grads = torch.autograd.grad(l_f_meta, (meta_net.parameters()), create_graph=True)
        # meta_net.update_params(lr, source_params=grads)
        
        val_images, val_labels = next(iter(subset_dataloader))
        # val_images, val_labels = next(iter(val_dataloader))
        val_images = val_images.to(device)
        val_labels = val_labels.to(device)

        y_g_hat = meta_net(val_images)
        l_g_meta = loss_fn(y_g_hat, val_labels)

        # grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True)[0]
        # grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True, allow_unused=True)[0]
        # print(grad_eps)

        with torch.no_grad():
            optimizer_meta.zero_grad()
            l_g_meta.backward()
            grad_eps = eps.grad
        
        # print(grad_eps)
        w_tilde = torch.clamp(grad_eps,min=0)
        # w_tilde = torch.clamp(-grad_eps,min=0)
        norm_c = torch.sum(w_tilde)

        if norm_c != 0:
            w = w_tilde / norm_c
        else:
            w = w_tilde
        
        # print(w)
        # break
        # Forward Pass
        outputs = model(images)
        loss = loss_fn_meta(outputs, labels)
        loss = torch.sum(loss*w)
        
        # Backward pass and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate on test set
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(test_dataloader):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    print(f"Epoch: [{epoch+1}/{epochs}], Accuracy: {accuracy:.4f}")

print("--- %s seconds ---" % (time.time() - start_time))

100%|██████████| 782/782 [00:38<00:00, 20.54it/s]
100%|██████████| 149/149 [00:12<00:00, 11.65it/s] 
  5%|▌         | 1/20 [00:50<16:06, 50.86s/it]

Epoch: [1/20], Accuracy: 0.3925


100%|██████████| 782/782 [00:37<00:00, 20.92it/s]
100%|██████████| 149/149 [00:12<00:00, 11.79it/s] 
 10%|█         | 2/20 [01:40<15:06, 50.37s/it]

Epoch: [2/20], Accuracy: 0.4705


100%|██████████| 782/782 [00:37<00:00, 20.74it/s]
100%|██████████| 149/149 [00:12<00:00, 11.61it/s] 
 15%|█▌        | 3/20 [02:31<14:17, 50.45s/it]

Epoch: [3/20], Accuracy: 0.4973


100%|██████████| 782/782 [00:38<00:00, 20.46it/s]
100%|██████████| 149/149 [00:12<00:00, 11.67it/s] 
 20%|██        | 4/20 [03:22<13:30, 50.67s/it]

Epoch: [4/20], Accuracy: 0.5048


100%|██████████| 782/782 [00:38<00:00, 20.40it/s]
100%|██████████| 149/149 [00:12<00:00, 11.72it/s] 
 25%|██▌       | 5/20 [04:13<12:42, 50.80s/it]

Epoch: [5/20], Accuracy: 0.5098


100%|██████████| 782/782 [00:37<00:00, 20.67it/s]
100%|██████████| 149/149 [00:12<00:00, 11.74it/s] 
 30%|███       | 6/20 [05:04<11:49, 50.71s/it]

Epoch: [6/20], Accuracy: 0.5087


100%|██████████| 782/782 [00:37<00:00, 20.87it/s]
100%|██████████| 149/149 [00:12<00:00, 11.64it/s] 
 35%|███▌      | 7/20 [05:54<10:57, 50.57s/it]

Epoch: [7/20], Accuracy: 0.5324


100%|██████████| 782/782 [00:37<00:00, 20.86it/s]
100%|██████████| 149/149 [01:07<00:00,  2.20it/s] 
 40%|████      | 8/20 [07:39<13:35, 67.97s/it]

Epoch: [8/20], Accuracy: 0.5362


100%|██████████| 782/782 [00:38<00:00, 20.45it/s]
100%|██████████| 149/149 [00:12<00:00, 11.63it/s]
 45%|████▌     | 9/20 [08:30<11:29, 62.69s/it]

Epoch: [9/20], Accuracy: 0.5518


100%|██████████| 782/782 [00:37<00:00, 20.60it/s]
100%|██████████| 149/149 [00:12<00:00, 11.52it/s]
 50%|█████     | 10/20 [09:21<09:50, 59.05s/it]

Epoch: [10/20], Accuracy: 0.5537


100%|██████████| 782/782 [00:38<00:00, 20.37it/s]
100%|██████████| 149/149 [00:12<00:00, 11.64it/s] 
 55%|█████▌    | 11/20 [10:12<08:29, 56.64s/it]

Epoch: [11/20], Accuracy: 0.5647


100%|██████████| 782/782 [00:38<00:00, 20.19it/s]
100%|██████████| 149/149 [00:12<00:00, 11.50it/s]
 60%|██████    | 12/20 [11:04<07:21, 55.14s/it]

Epoch: [12/20], Accuracy: 0.5621


100%|██████████| 782/782 [00:38<00:00, 20.35it/s]
100%|██████████| 149/149 [00:12<00:00, 11.64it/s] 
 65%|██████▌   | 13/20 [11:55<06:17, 53.96s/it]

Epoch: [13/20], Accuracy: 0.5765


100%|██████████| 782/782 [00:38<00:00, 20.47it/s]
100%|██████████| 149/149 [00:12<00:00, 11.53it/s]
 70%|███████   | 14/20 [12:46<05:18, 53.10s/it]

Epoch: [14/20], Accuracy: 0.5733


100%|██████████| 782/782 [00:38<00:00, 20.42it/s]
100%|██████████| 149/149 [00:12<00:00, 11.64it/s] 
 75%|███████▌  | 15/20 [13:37<04:22, 52.50s/it]

Epoch: [15/20], Accuracy: 0.5783


100%|██████████| 782/782 [00:38<00:00, 20.25it/s]
100%|██████████| 149/149 [00:13<00:00, 11.30it/s]
 80%|████████  | 16/20 [14:29<03:29, 52.29s/it]

Epoch: [16/20], Accuracy: 0.5848


100%|██████████| 782/782 [00:38<00:00, 20.25it/s]
100%|██████████| 149/149 [00:12<00:00, 11.51it/s]
 85%|████████▌ | 17/20 [15:21<02:36, 52.07s/it]

Epoch: [17/20], Accuracy: 0.5947


100%|██████████| 782/782 [00:38<00:00, 20.36it/s]
100%|██████████| 149/149 [00:12<00:00, 11.62it/s] 
 90%|█████████ | 18/20 [16:12<01:43, 51.82s/it]

Epoch: [18/20], Accuracy: 0.5984


100%|██████████| 782/782 [00:37<00:00, 20.98it/s]
100%|██████████| 149/149 [00:12<00:00, 11.61it/s] 
 95%|█████████▌| 19/20 [17:02<00:51, 51.32s/it]

Epoch: [19/20], Accuracy: 0.5919


100%|██████████| 782/782 [00:37<00:00, 21.02it/s]
100%|██████████| 149/149 [00:12<00:00, 11.77it/s] 
100%|██████████| 20/20 [17:52<00:00, 53.62s/it]

Epoch: [20/20], Accuracy: 0.6049
--- 1072.4640719890594 seconds ---





# Milo Setup

### Load Data

In [17]:
num_classes = 10
class_data = []
for i in range(num_classes):
    with open(f"milo-base/class-data/class_{i}.pkl", "rb") as f:
        S = pickle.load(f)
        class_data.append(S)

In [16]:
num_sets = len(class_data[0])
data = []
for i in range(num_sets):
    S = []
    for j in range(num_classes):
        S.extend(class_data[j][i])
    data.append(S)

### Define Dataloader

In [15]:
class SubDataset(Dataset):
    def __init__(self, indices, dataset):
        self.indices = indices
        self.dataset = dataset

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        data_point = self.dataset[index]
        return data_point

## Milo Training Loop

In [21]:
# Define Model
model = LeNet()
model = model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

# Train the model
epochs = 20
model.train()
start_time = time.time()
for epoch in tqdm(range(epochs)):
    # Train loop

    sub_dataset = SubDataset(indices=data[epoch], dataset=train_dataset)
    subset_train_dataloader = DataLoader(sub_dataset, batch_size=64, shuffle=True)
    
    for images, labels in tqdm(subset_train_dataloader):

        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        
        # Backward pass and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate on test set
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(test_dataloader):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    print(f"Epoch: [{epoch+1}/{epochs}], Accuracy: {accuracy:.4f}")

print("--- %s seconds ---" % (time.time() - start_time))

100%|██████████| 79/79 [00:00<00:00, 82.93it/s]
100%|██████████| 157/157 [00:13<00:00, 12.06it/s] 
  5%|▌         | 1/20 [00:13<04:25, 13.98s/it]

Epoch: [1/20], Accuracy: 0.2393


100%|██████████| 79/79 [00:00<00:00, 87.33it/s]
100%|██████████| 157/157 [00:12<00:00, 12.45it/s] 
 10%|█         | 2/20 [00:27<04:06, 13.71s/it]

Epoch: [2/20], Accuracy: 0.2537


100%|██████████| 79/79 [00:00<00:00, 87.85it/s]
100%|██████████| 157/157 [00:12<00:00, 12.42it/s] 
 15%|█▌        | 3/20 [00:41<03:51, 13.63s/it]

Epoch: [3/20], Accuracy: 0.2867


100%|██████████| 79/79 [00:00<00:00, 84.29it/s]
100%|██████████| 157/157 [00:12<00:00, 12.15it/s]
 20%|██        | 4/20 [00:54<03:39, 13.72s/it]

Epoch: [4/20], Accuracy: 0.2939


100%|██████████| 79/79 [00:00<00:00, 86.68it/s]
100%|██████████| 157/157 [00:12<00:00, 12.33it/s] 
 25%|██▌       | 5/20 [01:08<03:25, 13.70s/it]

Epoch: [5/20], Accuracy: 0.3160


100%|██████████| 79/79 [00:00<00:00, 84.82it/s]
100%|██████████| 157/157 [00:12<00:00, 12.35it/s] 
 30%|███       | 6/20 [01:22<03:11, 13.68s/it]

Epoch: [6/20], Accuracy: 0.3397


100%|██████████| 79/79 [00:00<00:00, 80.12it/s]
100%|██████████| 157/157 [00:13<00:00, 12.01it/s] 
 35%|███▌      | 7/20 [01:36<02:59, 13.80s/it]

Epoch: [7/20], Accuracy: 0.3366


100%|██████████| 79/79 [00:00<00:00, 80.35it/s]
100%|██████████| 157/157 [00:12<00:00, 12.11it/s] 
 40%|████      | 8/20 [01:50<02:46, 13.85s/it]

Epoch: [8/20], Accuracy: 0.3557


100%|██████████| 79/79 [00:00<00:00, 87.21it/s]
100%|██████████| 157/157 [00:12<00:00, 12.42it/s] 
 45%|████▌     | 9/20 [02:03<02:31, 13.76s/it]

Epoch: [9/20], Accuracy: 0.3709


100%|██████████| 79/79 [00:00<00:00, 88.80it/s]
100%|██████████| 157/157 [00:12<00:00, 12.45it/s] 
 50%|█████     | 10/20 [02:17<02:16, 13.68s/it]

Epoch: [10/20], Accuracy: 0.3620


100%|██████████| 79/79 [00:00<00:00, 88.00it/s]
100%|██████████| 157/157 [00:12<00:00, 12.43it/s] 
 55%|█████▌    | 11/20 [02:30<02:02, 13.64s/it]

Epoch: [11/20], Accuracy: 0.3781


100%|██████████| 79/79 [00:00<00:00, 84.63it/s]
100%|██████████| 157/157 [00:12<00:00, 12.15it/s] 
 60%|██████    | 12/20 [02:44<01:49, 13.70s/it]

Epoch: [12/20], Accuracy: 0.3872


100%|██████████| 79/79 [00:00<00:00, 83.30it/s]
100%|██████████| 157/157 [00:12<00:00, 12.22it/s] 
 65%|██████▌   | 13/20 [02:58<01:36, 13.73s/it]

Epoch: [13/20], Accuracy: 0.3881


100%|██████████| 79/79 [00:01<00:00, 78.85it/s]
100%|██████████| 157/157 [00:12<00:00, 12.36it/s]
 70%|███████   | 14/20 [03:12<01:22, 13.73s/it]

Epoch: [14/20], Accuracy: 0.3881


100%|██████████| 79/79 [00:01<00:00, 78.81it/s]
100%|██████████| 157/157 [00:12<00:00, 12.36it/s]
 75%|███████▌  | 15/20 [03:25<01:08, 13.72s/it]

Epoch: [15/20], Accuracy: 0.3705


100%|██████████| 79/79 [00:00<00:00, 80.62it/s]
100%|██████████| 157/157 [00:12<00:00, 12.37it/s] 
 80%|████████  | 16/20 [03:39<00:54, 13.71s/it]

Epoch: [16/20], Accuracy: 0.3971


100%|██████████| 79/79 [00:00<00:00, 81.52it/s]
100%|██████████| 157/157 [00:12<00:00, 12.23it/s] 
 85%|████████▌ | 17/20 [03:53<00:41, 13.74s/it]

Epoch: [17/20], Accuracy: 0.3985


100%|██████████| 79/79 [00:00<00:00, 86.42it/s]
100%|██████████| 157/157 [00:12<00:00, 12.45it/s]
 90%|█████████ | 18/20 [04:06<00:27, 13.68s/it]

Epoch: [18/20], Accuracy: 0.4113


100%|██████████| 79/79 [00:00<00:00, 80.38it/s]
100%|██████████| 157/157 [00:12<00:00, 12.38it/s]
 95%|█████████▌| 19/20 [04:20<00:13, 13.68s/it]

Epoch: [19/20], Accuracy: 0.4032


100%|██████████| 79/79 [00:00<00:00, 88.00it/s]
100%|██████████| 157/157 [00:12<00:00, 12.40it/s]
100%|██████████| 20/20 [04:34<00:00, 13.71s/it]

Epoch: [20/20], Accuracy: 0.4157
--- 274.18256092071533 seconds ---





In [25]:
# Load CIFAR10 datasets
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test)

split_ratio = 0.95 # 80% for test, 20% for validation
n_samples = len(test_dataset)
n_test = int(n_samples * split_ratio)
n_val = n_samples - n_test
test_dataset, val_dataset = random_split(test_dataset, [n_test, n_val])


# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [26]:
subset_sampler = RandomSubsetSampler(val_dataset, 64)
subset_dataloader = DataLoader(val_dataset, sampler=subset_sampler)

### Meta-Milo Training loop

In [27]:
# Define Model
model = LeNet()
model = model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()
loss_fn_meta = nn.CrossEntropyLoss(reduction='none')

# Train the model
epochs = 20
lr = 0.001
model.train()
start_time = time.time()
for epoch in tqdm(range(epochs)):
    # Train loop
    sub_dataset = SubDataset(indices=data[epoch], dataset=train_dataset)
    # train_dataloader = DataLoader(sub_dataset, batch_size=64, shuffle=True, num_workers=2)
    train_dataloader = DataLoader(sub_dataset, batch_size=64, shuffle=True)
    
    for images, labels in tqdm(train_dataloader):
        
        images = images.to(device)
        labels = labels.to(device)

        # meta_net = get_cifar10_model()
        meta_net = LeNet()
        meta_net.load_state_dict(model.state_dict())

        meta_net = meta_net.to(device)

        optimizer_meta = torch.optim.Adam(meta_net.parameters())

        meta_net.train()
        
        y_f_hat = meta_net(images)
        cost = loss_fn_meta(y_f_hat, labels)
        eps = torch.zeros(cost.size(), requires_grad=True).to(device)
        l_f_meta = torch.sum(cost*eps)

        # meta_net.zero_grad()
        optimizer_meta.zero_grad()
        eps.retain_grad()
        l_f_meta.backward()
        optimizer_meta.step()

        meta_net.eval()

        # grads = torch.autograd.grad(l_f_meta, (meta_net.parameters()), create_graph=True)
        # meta_net.update_params(lr, source_params=grads)
        
        val_images, val_labels = next(iter(subset_dataloader))
        # val_images, val_labels = next(iter(val_dataloader))
        val_images = val_images.to(device)
        val_labels = val_labels.to(device)

        y_g_hat = meta_net(val_images)
        l_g_meta = loss_fn(y_g_hat, val_labels)

        # grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True)[0]
        # grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True, allow_unused=True)[0]
        # print(grad_eps)

        with torch.no_grad():
            optimizer_meta.zero_grad()
            l_g_meta.backward()
            grad_eps = eps.grad
        
        # print(grad_eps)
        w_tilde = torch.clamp(grad_eps,min=0)
        # w_tilde = torch.clamp(-grad_eps,min=0)
        norm_c = torch.sum(w_tilde)

        if norm_c != 0:
            w = w_tilde / norm_c
        else:
            w = w_tilde
        
        # print(w)
        # break
        # Forward Pass
        outputs = model(images)
        loss = loss_fn_meta(outputs, labels)
        loss = torch.sum(loss*w)
        
        # Backward pass and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate on test set
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(test_dataloader):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    print(f"Epoch: [{epoch+1}/{epochs}], Accuracy: {accuracy:.4f}")

print("--- %s seconds ---" % (time.time() - start_time))

100%|██████████| 79/79 [00:03<00:00, 19.99it/s]
100%|██████████| 149/149 [00:00<00:00, 159.63it/s]
  5%|▌         | 1/20 [00:04<01:32,  4.89s/it]

Epoch: [1/20], Accuracy: 0.2315


100%|██████████| 79/79 [00:02<00:00, 36.78it/s]
100%|██████████| 149/149 [00:00<00:00, 163.74it/s]
 10%|█         | 2/20 [00:07<01:08,  3.81s/it]

Epoch: [2/20], Accuracy: 0.2860


100%|██████████| 79/79 [00:02<00:00, 37.44it/s]
100%|██████████| 149/149 [00:00<00:00, 163.85it/s]
 15%|█▌        | 3/20 [00:10<00:58,  3.45s/it]

Epoch: [3/20], Accuracy: 0.3108


100%|██████████| 79/79 [00:02<00:00, 37.50it/s]
100%|██████████| 149/149 [00:00<00:00, 164.38it/s]
 20%|██        | 4/20 [00:13<00:52,  3.28s/it]

Epoch: [4/20], Accuracy: 0.3161


100%|██████████| 79/79 [00:02<00:00, 36.86it/s]
100%|██████████| 149/149 [00:00<00:00, 163.74it/s]
 25%|██▌       | 5/20 [00:17<00:47,  3.20s/it]

Epoch: [5/20], Accuracy: 0.3356


100%|██████████| 79/79 [00:02<00:00, 36.86it/s]
100%|██████████| 149/149 [00:00<00:00, 157.10it/s]
 30%|███       | 6/20 [00:20<00:44,  3.16s/it]

Epoch: [6/20], Accuracy: 0.3442


100%|██████████| 79/79 [00:02<00:00, 37.46it/s]
100%|██████████| 149/149 [00:00<00:00, 165.97it/s]
 35%|███▌      | 7/20 [00:23<00:40,  3.11s/it]

Epoch: [7/20], Accuracy: 0.3377


100%|██████████| 79/79 [00:02<00:00, 37.70it/s]
100%|██████████| 149/149 [00:00<00:00, 164.50it/s]
 40%|████      | 8/20 [00:26<00:36,  3.08s/it]

Epoch: [8/20], Accuracy: 0.3669


100%|██████████| 79/79 [00:02<00:00, 36.51it/s]
100%|██████████| 149/149 [00:00<00:00, 164.98it/s]
 45%|████▌     | 9/20 [00:29<00:33,  3.08s/it]

Epoch: [9/20], Accuracy: 0.3692


100%|██████████| 79/79 [00:02<00:00, 37.62it/s]
100%|██████████| 149/149 [00:00<00:00, 163.58it/s]
 50%|█████     | 10/20 [00:32<00:30,  3.06s/it]

Epoch: [10/20], Accuracy: 0.3881


100%|██████████| 79/79 [00:02<00:00, 37.72it/s]
100%|██████████| 149/149 [00:00<00:00, 163.88it/s]
 55%|█████▌    | 11/20 [00:35<00:27,  3.04s/it]

Epoch: [11/20], Accuracy: 0.3820


100%|██████████| 79/79 [00:02<00:00, 37.05it/s]
100%|██████████| 149/149 [00:00<00:00, 154.87it/s]
 60%|██████    | 12/20 [00:38<00:24,  3.06s/it]

Epoch: [12/20], Accuracy: 0.4081


100%|██████████| 79/79 [00:02<00:00, 37.63it/s]
100%|██████████| 149/149 [00:00<00:00, 163.85it/s]
 65%|██████▌   | 13/20 [00:41<00:21,  3.04s/it]

Epoch: [13/20], Accuracy: 0.3860


100%|██████████| 79/79 [00:02<00:00, 36.98it/s]
100%|██████████| 149/149 [00:00<00:00, 164.52it/s]
 70%|███████   | 14/20 [00:44<00:18,  3.04s/it]

Epoch: [14/20], Accuracy: 0.4075


100%|██████████| 79/79 [00:02<00:00, 37.72it/s]
100%|██████████| 149/149 [00:00<00:00, 164.81it/s]
 75%|███████▌  | 15/20 [00:47<00:15,  3.03s/it]

Epoch: [15/20], Accuracy: 0.4094


100%|██████████| 79/79 [00:02<00:00, 37.98it/s]
100%|██████████| 149/149 [00:00<00:00, 163.79it/s]
 80%|████████  | 16/20 [00:50<00:12,  3.02s/it]

Epoch: [16/20], Accuracy: 0.4195


100%|██████████| 79/79 [00:02<00:00, 36.08it/s]
100%|██████████| 149/149 [00:00<00:00, 162.64it/s]
 85%|████████▌ | 17/20 [00:53<00:09,  3.05s/it]

Epoch: [17/20], Accuracy: 0.4220


100%|██████████| 79/79 [00:02<00:00, 37.38it/s]
100%|██████████| 149/149 [00:00<00:00, 164.87it/s]
 90%|█████████ | 18/20 [00:56<00:06,  3.04s/it]

Epoch: [18/20], Accuracy: 0.4200


100%|██████████| 79/79 [00:02<00:00, 37.61it/s]
100%|██████████| 149/149 [00:00<00:00, 164.66it/s]
 95%|█████████▌| 19/20 [00:59<00:03,  3.03s/it]

Epoch: [19/20], Accuracy: 0.4281


100%|██████████| 79/79 [00:02<00:00, 36.94it/s]
100%|██████████| 149/149 [00:00<00:00, 163.73it/s]
100%|██████████| 20/20 [01:02<00:00,  3.13s/it]

Epoch: [20/20], Accuracy: 0.4261
--- 62.580097913742065 seconds ---



