In [1]:
import pickle
import random
import statistics
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.func import functional_call
from torch.utils.data import random_split, Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.models import resnet18
from torchvision.models.resnet import ResNet18_Weights
from tqdm import tqdm
from tqdm import tqdm


# Global seeds for reproducibility. Python's `random` is seeded as well because the
# validation sampler (RandomSubsetSampler) draws its indices with random.sample.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f634e052570>

In [15]:
# Pick the compute device: Apple MPS, then CUDA, then CPU.
preferred_cuda_index = 5  # GPU to use on the shared machine; change as needed
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    # Clamp to the GPUs actually visible so machines with fewer than
    # preferred_cuda_index + 1 GPUs do not fail later with "invalid device ordinal".
    device = f"cuda:{min(preferred_cuda_index, torch.cuda.device_count() - 1)}"
else:
    device = "cpu"

In [1]:
subset_fraction = 0.3
num_runs = 5
split_ratio = 0.9
epochs = 20

In [17]:
print(f"{device}")  # show which accelerator the experiments will run on

cuda:5


### Load ResNet Model

In [4]:
def resnet_with_new_head(builder, num_classes):
    """Build an untrained torchvision ResNet and replace its final FC layer.

    Args:
        builder: a torchvision ResNet constructor (e.g. ``torchvision.models.resnet18``);
            it is called with ``weights=None``, i.e. random initialisation.
        num_classes: number of logits produced by the new ``fc`` layer.

    Returns:
        The model with ``fc`` replaced by ``nn.Linear(fc.in_features, num_classes)``.
    """
    model = builder(weights=None)  # pass a Weights enum instead of None for pretrained models
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


def get_resent18_model(num_classes=10):
    """Untrained ResNet-18 with a ``num_classes``-way classification head."""
    return resnet_with_new_head(torchvision.models.resnet18, num_classes)


def get_resent101_model(num_classes=10):
    """Untrained ResNet-101 with a ``num_classes``-way classification head."""
    return resnet_with_new_head(torchvision.models.resnet101, num_classes)

In [5]:
# Sanity check: the factory returns a torchvision ResNet instance.
model = get_resent101_model(num_classes=10)
print(type(model))

<class 'torchvision.models.resnet.ResNet'>


In [6]:
# # Freeze pre-trained layers
# for param in model.parameters():
#     param.requires_grad = False

# # Unfreeze some layers for fine-tuning
# for param in model.layer4.parameters():
#     param.requires_grad = True

In [7]:
# Per-channel normalisation statistics, shared by the train and test pipelines so the
# two can never drift apart.
# NOTE(review): these std values are the widely copied (0.2023, 0.1994, 0.2010); the
# per-pixel std of the CIFAR-10 training set is usually quoted as (0.2470, 0.2435, 0.2616).
# Kept unchanged so results stay comparable with earlier runs — verify before changing.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

# Training: augmentation (pad by 4 and random 32x32 crop, random horizontal flip),
# then tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

# Evaluation: deterministic, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

In [8]:
# CIFAR-10 train / test splits, downloaded into ./data on first use.
train_dataset = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform_train
)
test_dataset = datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform_test
)

# Training batches are reshuffled every epoch; evaluation order is fixed.
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [10]:
print(train_dataset[0][0].size())  # (channels, height, width) of one transformed training image

torch.Size([3, 32, 32])


## LeNet Model Definition

In [13]:
class LeNet(nn.Module):
    """LeNet-5-style CNN for 3x32x32 images.

    Two conv(5x5) -> ReLU -> 2x2 average-pool stages turn a 3x32x32 input into a
    16x5x5 feature map; a three-layer MLP maps the flattened 400 features to raw
    class logits (no softmax, so pair it with ``nn.CrossEntropyLoss``).

    Args:
        out_classes: number of output logits.
    """

    def __init__(self, out_classes=10):
        super().__init__()
        # Sub-module names and Sequential positions determine the state_dict keys
        # (e.g. "feature_extractor.0.weight"), which load_state_dict between
        # LeNet instances relies on.
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5),     # 3x32x32 -> 6x28x28
            nn.ReLU(),
            nn.AvgPool2d(2),                    # -> 6x14x14
            nn.Conv2d(6, 16, kernel_size=5),    # -> 16x10x10
            nn.ReLU(),
            nn.AvgPool2d(2),                    # -> 16x5x5
        )
        self.classifier = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, out_classes),
        )

    def forward(self, x):
        features = torch.flatten(self.feature_extractor(x), 1)
        return self.classifier(features)

# Baseline Model Training

### Basic Training Loop (ResNet-101, full training set)

In [None]:
# Baseline: ResNet-101 trained on the full training set, repeated num_runs times.
seed = 42
torch.manual_seed(seed)

time_per_run = []
acc_per_run = []

for i in range(num_runs):
    # Define the Model
    model = get_resent101_model(10)
    model = model.to(device)

    # Define optimizer and loss function
    optimizer = torch.optim.Adam(model.parameters())
    loss_fn = nn.CrossEntropyLoss()

    # Train the model
    model.train()
    start_time = time.time()
    for epoch in tqdm(range(epochs)):
        for images, labels in train_dataloader:
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = loss_fn(outputs, labels)

            # Backward pass and update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    time_taken = time.time() - start_time
    time_per_run.append(time_taken)
    print("--- %s seconds ---" % (time_taken))

    # Evaluate on test set. eval() is essential here: torchvision ResNets contain
    # BatchNorm layers, which in train mode normalise with per-batch statistics and
    # keep updating their running averages on the test data.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(test_dataloader):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    acc_per_run.append(accuracy)
    print(f"Accuracy: {accuracy:.4f}")

# statistics.stdev needs at least two samples; report 0.0 for a single run.
acc_mean = statistics.mean(acc_per_run)
acc_std = statistics.stdev(acc_per_run) if len(acc_per_run) > 1 else 0.0

time_mean = statistics.mean(time_per_run)
time_std = statistics.stdev(time_per_run) if len(time_per_run) > 1 else 0.0

print(f"mean accuracy:{acc_mean}, std accuracy:{acc_std}")
print(f"mean time:{time_mean}, std time:{time_std}")


# Reweighted Model Training

In [None]:
# Load CIFAR10 datasets
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test)

# Carve a small validation set (used by the meta-reweighting step) out of the test set.
# A dedicated generator seeded with `seed` makes the split reproducible regardless of
# how much of the global RNG earlier cells consumed.
n_samples = len(test_dataset)
n_test = int(n_samples * split_ratio)
n_val = n_samples - n_test
test_dataset, val_dataset = random_split(
    test_dataset, [n_test, n_val], generator=torch.Generator().manual_seed(seed)
)


# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)

### Random Sampler for Sampling Validation Data

In [None]:
class RandomSubsetSampler(torch.utils.data.Sampler):
    """Sampler that yields ``subset_size`` distinct random indices of ``dataset``.

    A new subset is drawn every time the sampler is iterated, so calling
    ``next(iter(loader))`` on a DataLoader built with it returns a fresh random
    batch each time.

    Args:
        dataset: any sized collection; only ``len(dataset)`` is used.
        subset_size: number of indices drawn (without replacement) per iteration.
        rng: optional ``random.Random`` instance for reproducible draws. Defaults
            to the module-level ``random`` functions (the original behaviour).

    Raises:
        ValueError: if ``subset_size`` is negative or larger than ``len(dataset)``
            (``random.sample`` would otherwise fail only at iteration time).
    """

    def __init__(self, dataset, subset_size, rng=None):
        if not 0 <= subset_size <= len(dataset):
            raise ValueError(
                f"subset_size must be in [0, {len(dataset)}], got {subset_size}"
            )
        self.dataset = dataset
        self.subset_size = subset_size
        self.rng = random if rng is None else rng

    def __iter__(self):
        indices = self.rng.sample(range(len(self.dataset)), self.subset_size)
        return iter(indices)

    def __len__(self):
        return self.subset_size

# One fresh random batch of 64 validation examples per `next(iter(subset_dataloader))`.
# batch_size must match the sampler size: DataLoader's default batch_size=1 would
# otherwise hand the meta step a single validation image.
subset_sampler = RandomSubsetSampler(val_dataset, 64)
subset_dataloader = DataLoader(val_dataset, batch_size=len(subset_sampler), sampler=subset_sampler)

### Meta-Reweighting Baseline Training Loop

In [None]:
# Learning-to-reweight (Ren et al., 2018) on the full training set with LeNet.
seed = 42
torch.manual_seed(seed)

time_per_run = []
acc_per_run = []

# Step size of the *virtual* SGD step used to score training examples. The resulting
# weights are normalised to sum to 1, so any positive value yields (up to float error)
# the same weights.
meta_lr = 1e-3

for i in range(num_runs):
    # Define the Model
    model = LeNet()
    model = model.to(device)

    # Define optimizer and loss function
    optimizer = torch.optim.Adam(model.parameters())
    loss_fn = nn.CrossEntropyLoss()
    loss_fn_meta = nn.CrossEntropyLoss(reduction='none')  # per-example losses

    # Train the model
    model.train()
    start_time = time.time()
    for epoch in tqdm(range(epochs)):
        for images, labels in train_dataloader:
            images = images.to(device)
            labels = labels.to(device)

            # --- Meta step: score every example of the batch -----------------
            # Per-example weights eps (all zero) enter the training loss. eps is a
            # leaf, so the validation loss after the virtual step can be
            # differentiated with respect to it.
            params = dict(model.named_parameters())
            cost = loss_fn_meta(model(images), labels)
            eps = torch.zeros_like(cost, requires_grad=True)
            l_f_meta = torch.sum(cost * eps)

            # Differentiable virtual step theta' = theta - meta_lr * grad(l_f_meta).
            # create_graph=True keeps theta' a function of eps. An optimizer.step()
            # on a copied network is an in-place, non-differentiable update, so
            # the validation loss would not depend on eps at all.
            grads = torch.autograd.grad(l_f_meta, tuple(params.values()), create_graph=True)
            fast_params = {
                name: p - meta_lr * g for (name, p), g in zip(params.items(), grads)
            }

            # Validation loss of the virtually updated network (eval mode, as a
            # separate meta network would be).
            val_images, val_labels = next(iter(subset_dataloader))
            val_images = val_images.to(device)
            val_labels = val_labels.to(device)
            model.eval()
            y_g_hat = functional_call(model, fast_params, (val_images,))
            model.train()
            l_g_meta = loss_fn(y_g_hat, val_labels)
            grad_eps = torch.autograd.grad(l_g_meta, eps)[0]

            # Keep examples whose up-weighting *lowers* the validation loss
            # (negative gradient), then normalise the weights to sum to 1.
            w_tilde = torch.clamp(-grad_eps, min=0)
            norm_c = torch.sum(w_tilde)
            if norm_c != 0:
                w = w_tilde / norm_c
            else:
                w = w_tilde

            # --- Real update with the per-example weights --------------------
            outputs = model(images)
            loss = torch.sum(loss_fn_meta(outputs, labels) * w)

            # Backward pass and update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    time_taken = time.time() - start_time
    time_per_run.append(time_taken)
    print("--- %s seconds ---" % (time_taken))

    # Evaluate on test set (eval mode so normalisation/dropout layers, if any,
    # behave deterministically).
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(test_dataloader):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    acc_per_run.append(accuracy)
    print(f"Accuracy: {accuracy:.4f}")



# Milo Setup

### Load Data

In [None]:
# Pre-computed MILO subsets, one pickle per class. Each file presumably holds a list whose
# i-th entry is the training-set indices of that class chosen for subset i (they are
# combined as class_data[j][i] below) — TODO confirm against the generating script.
# subset_fraction comes from the config cell instead of being re-hard-coded here, so the
# MILO subsets and the random subsets are always built for the same fraction.
# NOTE(review): pickle.load can execute arbitrary code; only load files you produced.
num_classes = 10
class_data = []
for i in range(num_classes):
    with open(f"milo-base/class-data-{subset_fraction}/class_{i}.pkl", "rb") as f:
        class_data.append(pickle.load(f))

In [None]:
# Merge the per-class index lists: data[i] is subset i over all classes
# (indices of class 0 first, then class 1, ...).
num_sets = len(class_data[0])
data = [
    [idx for cls in range(num_classes) for idx in class_data[cls][i]]
    for i in range(num_sets)
]

In [None]:
print(f"{len(data[0])}")  # number of training indices in the first MILO subset

### Define Dataloader

In [None]:
class SubDataset(Dataset):
    """View of ``dataset`` restricted to the positions listed in ``indices``.

    Item ``k`` of the view is ``dataset[indices[k]]``, so the underlying dataset's
    transforms are applied unchanged.
    """

    def __init__(self, indices, dataset):
        self.indices = indices
        self.dataset = dataset

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

## Milo Training Loop

In [None]:
# Plain training on the pre-computed MILO subsets with LeNet.
seed = 42
torch.manual_seed(seed)  # same initial RNG state as the other experiments

time_per_run = []
acc_per_run = []

R = 1  # move to the next pre-computed MILO subset every R epochs

# One subset is consumed every R epochs; fail early with a clear message rather than
# with an IndexError in the middle of training.
n_subsets_needed = (epochs + R - 1) // R
if len(data) < n_subsets_needed:
    raise ValueError(
        f"need {n_subsets_needed} MILO subsets for {epochs} epochs with R={R}, found {len(data)}"
    )

for i in range(num_runs):
    # Define Model
    model = LeNet()
    model = model.to(device)

    # Define optimizer and loss function
    optimizer = torch.optim.Adam(model.parameters())
    loss_fn = nn.CrossEntropyLoss()

    # Train the model
    model.train()
    start_time = time.time()
    for epoch in tqdm(range(epochs)):
        # Switch to the next subset at the start of every R-epoch window.
        if epoch % R == 0:
            sub_dataset = SubDataset(indices=data[epoch // R], dataset=train_dataset)
            subset_train_dataloader = DataLoader(sub_dataset, batch_size=64, shuffle=True)

        for images, labels in subset_train_dataloader:
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = loss_fn(outputs, labels)

            # Backward pass and update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    time_taken = time.time() - start_time
    time_per_run.append(time_taken)
    print("--- %s seconds ---" % (time_taken))

    # Evaluate on test set (eval mode so normalisation/dropout layers, if any,
    # behave deterministically).
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(test_dataloader):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    acc_per_run.append(accuracy)
    print(f"Accuracy: {accuracy:.4f}")


In [None]:
# Load CIFAR10 datasets
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test)

# Hold out a validation set (for the meta-reweighting step) from the test set. The
# generator seeded with `seed` gives the same split on every run of this cell,
# independent of the global RNG state.
n_samples = len(test_dataset)
n_test = int(n_samples * split_ratio)
n_val = n_samples - n_test
test_dataset, val_dataset = random_split(
    test_dataset, [n_test, n_val], generator=torch.Generator().manual_seed(seed)
)


# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [None]:
# One fresh random batch of 64 validation examples per `next(iter(subset_dataloader))`.
# batch_size must match the sampler size: DataLoader's default batch_size=1 would
# otherwise hand the meta step a single validation image.
subset_sampler = RandomSubsetSampler(val_dataset, 64)
subset_dataloader = DataLoader(val_dataset, batch_size=len(subset_sampler), sampler=subset_sampler)

### Meta-Milo Training loop

In [None]:
# Learning-to-reweight (Ren et al., 2018) on top of the pre-computed MILO subsets.
seed = 42
torch.manual_seed(seed)  # same initial RNG state as the other experiments

time_per_run = []
acc_per_run = []

# Defined here as well, so this cell does not depend on the MILO cell having run.
R = 1  # move to the next pre-computed MILO subset every R epochs
# Step size of the *virtual* SGD step used to score training examples; the weights
# are normalised to sum to 1, so any positive value yields the same weights.
meta_lr = 1e-3

# One subset is consumed every R epochs; fail early with a clear message rather than
# with an IndexError in the middle of training.
n_subsets_needed = (epochs + R - 1) // R
if len(data) < n_subsets_needed:
    raise ValueError(
        f"need {n_subsets_needed} MILO subsets for {epochs} epochs with R={R}, found {len(data)}"
    )

for i in range(num_runs):
    # Define Model
    model = LeNet()
    model = model.to(device)

    # Define optimizer and loss function
    optimizer = torch.optim.Adam(model.parameters())
    loss_fn = nn.CrossEntropyLoss()
    loss_fn_meta = nn.CrossEntropyLoss(reduction='none')  # per-example losses

    # Train the model
    model.train()
    start_time = time.time()
    for epoch in tqdm(range(epochs)):
        if epoch % R == 0:
            sub_dataset = SubDataset(indices=data[epoch // R], dataset=train_dataset)
            # Local loader name: assigning to `train_dataloader` here would silently
            # replace the full-training-set loader for every later cell.
            milo_train_dataloader = DataLoader(sub_dataset, batch_size=64, shuffle=True)

        for images, labels in milo_train_dataloader:
            images = images.to(device)
            labels = labels.to(device)

            # --- Meta step: score every example of the batch -----------------
            # Per-example weights eps (all zero) enter the training loss; eps is a
            # leaf so the post-step validation loss can be differentiated w.r.t. it.
            params = dict(model.named_parameters())
            cost = loss_fn_meta(model(images), labels)
            eps = torch.zeros_like(cost, requires_grad=True)
            l_f_meta = torch.sum(cost * eps)

            # Differentiable virtual step theta' = theta - meta_lr * grad(l_f_meta);
            # create_graph=True keeps theta' a function of eps (an optimizer.step()
            # on a copied network would be in-place and non-differentiable).
            grads = torch.autograd.grad(l_f_meta, tuple(params.values()), create_graph=True)
            fast_params = {
                name: p - meta_lr * g for (name, p), g in zip(params.items(), grads)
            }

            # Validation loss of the virtually updated network.
            val_images, val_labels = next(iter(subset_dataloader))
            val_images = val_images.to(device)
            val_labels = val_labels.to(device)
            model.eval()
            y_g_hat = functional_call(model, fast_params, (val_images,))
            model.train()
            l_g_meta = loss_fn(y_g_hat, val_labels)
            grad_eps = torch.autograd.grad(l_g_meta, eps)[0]

            # Keep examples whose up-weighting *lowers* the validation loss
            # (negative gradient), then normalise the weights to sum to 1.
            w_tilde = torch.clamp(-grad_eps, min=0)
            norm_c = torch.sum(w_tilde)
            if norm_c != 0:
                w = w_tilde / norm_c
            else:
                w = w_tilde

            # --- Real update with the per-example weights --------------------
            outputs = model(images)
            loss = torch.sum(loss_fn_meta(outputs, labels) * w)

            # Backward pass and update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    time_taken = time.time() - start_time
    time_per_run.append(time_taken)
    print("--- %s seconds ---" % (time_taken))

    # Evaluate on test set (eval mode so normalisation/dropout layers, if any,
    # behave deterministically).
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(test_dataloader):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    acc_per_run.append(accuracy)
    print(f"Accuracy: {accuracy:.4f}")

## Random Subset

In [9]:
# Load CIFAR10 datasets
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test)


# Uniformly random subset holding `subset_fraction` of the training set. A dedicated
# generator seeded with `seed` selects the same images on every run, independent of
# how much of the global RNG earlier cells consumed.
subset_size = int(subset_fraction * len(train_dataset))
subset_indices = torch.randperm(
    len(train_dataset), generator=torch.Generator().manual_seed(seed)
)[:subset_size]
subset_dataset = torch.utils.data.Subset(train_dataset, subset_indices)


# Create dataloaders
train_dataloader = DataLoader(subset_dataset, batch_size=64, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [19]:
print(f"{device}")  # confirm the accelerator before the random-subset runs

cuda:5


In [22]:
# Plain LeNet training on the random subset, repeated num_runs times.
seed = 42
torch.manual_seed(seed)

time_per_run = []
acc_per_run = []

for i in range(num_runs):
    # Define the Model
    model = LeNet()
    model = model.to(device)

    # Define optimizer and loss function
    optimizer = torch.optim.Adam(model.parameters())
    loss_fn = nn.CrossEntropyLoss()

    # Train the model
    model.train()
    start_time = time.time()
    for epoch in tqdm(range(epochs)):
        for images, labels in train_dataloader:
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = loss_fn(outputs, labels)

            # Backward pass and update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    time_taken = time.time() - start_time
    time_per_run.append(time_taken)
    print("--- %s seconds ---" % (time_taken))

    # Evaluate on test set. eval() keeps the protocol correct if LeNet is swapped for a
    # model with BatchNorm/Dropout (as in the ResNet baseline).
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(test_dataloader):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    acc_per_run.append(accuracy)
    print(f"Accuracy: {accuracy:.4f}")


100%|██████████| 235/235 [00:02<00:00, 93.38it/s] 
100%|██████████| 235/235 [00:02<00:00, 97.50it/s] 
100%|██████████| 235/235 [00:02<00:00, 100.48it/s]
100%|██████████| 235/235 [00:02<00:00, 95.77it/s]
100%|██████████| 235/235 [00:02<00:00, 87.00it/s]
100%|██████████| 235/235 [00:02<00:00, 80.80it/s]
100%|██████████| 235/235 [00:02<00:00, 98.00it/s] 
100%|██████████| 235/235 [00:02<00:00, 104.23it/s]
100%|██████████| 235/235 [00:02<00:00, 99.71it/s] 
100%|██████████| 235/235 [00:02<00:00, 84.56it/s]
100%|██████████| 235/235 [00:02<00:00, 78.39it/s] 
100%|██████████| 235/235 [00:02<00:00, 98.65it/s] 
100%|██████████| 235/235 [00:02<00:00, 101.49it/s]
100%|██████████| 235/235 [00:02<00:00, 91.27it/s]
100%|██████████| 235/235 [00:02<00:00, 82.49it/s]
100%|██████████| 235/235 [00:02<00:00, 97.66it/s]
100%|██████████| 235/235 [00:02<00:00, 98.85it/s]
100%|██████████| 235/235 [00:02<00:00, 98.26it/s] 
100%|██████████| 235/235 [00:02<00:00, 90.93it/s] 
100%|██████████| 235/235 [00:02<00:00, 

--- 50.73576307296753 seconds ---


100%|██████████| 157/157 [00:01<00:00, 133.48it/s]


Accuracy: 0.5683


100%|██████████| 235/235 [00:02<00:00, 94.88it/s] 
100%|██████████| 235/235 [00:02<00:00, 90.04it/s] 
100%|██████████| 235/235 [00:02<00:00, 80.06it/s] 
100%|██████████| 235/235 [00:02<00:00, 79.08it/s] 
100%|██████████| 235/235 [00:02<00:00, 93.72it/s] 
100%|██████████| 235/235 [00:02<00:00, 96.93it/s] 
100%|██████████| 235/235 [00:02<00:00, 86.91it/s]
100%|██████████| 235/235 [00:03<00:00, 74.59it/s]
100%|██████████| 235/235 [00:02<00:00, 90.88it/s]
100%|██████████| 235/235 [00:02<00:00, 98.90it/s] 
100%|██████████| 235/235 [00:02<00:00, 97.53it/s] 
100%|██████████| 235/235 [00:02<00:00, 88.46it/s] 
100%|██████████| 235/235 [00:02<00:00, 85.06it/s]
100%|██████████| 235/235 [00:02<00:00, 87.82it/s] 
100%|██████████| 235/235 [00:02<00:00, 93.45it/s] 
100%|██████████| 235/235 [00:02<00:00, 99.33it/s] 
100%|██████████| 235/235 [00:02<00:00, 87.18it/s]
100%|██████████| 235/235 [00:02<00:00, 84.33it/s]
100%|██████████| 235/235 [00:02<00:00, 92.64it/s] 
100%|██████████| 235/235 [00:02<00:00

--- 52.57818102836609 seconds ---


100%|██████████| 157/157 [00:01<00:00, 139.99it/s]


Accuracy: 0.5755


100%|██████████| 235/235 [00:02<00:00, 89.82it/s]
100%|██████████| 235/235 [00:03<00:00, 76.64it/s]
100%|██████████| 235/235 [00:02<00:00, 87.76it/s] 
100%|██████████| 235/235 [00:02<00:00, 99.43it/s] 
100%|██████████| 235/235 [00:02<00:00, 95.99it/s] 
100%|██████████| 235/235 [00:02<00:00, 86.20it/s]
100%|██████████| 235/235 [00:03<00:00, 74.72it/s]
100%|██████████| 235/235 [00:02<00:00, 84.03it/s] 
100%|██████████| 235/235 [00:02<00:00, 90.97it/s] 
100%|██████████| 235/235 [00:02<00:00, 96.20it/s] 
100%|██████████| 235/235 [00:02<00:00, 88.57it/s]
100%|██████████| 235/235 [00:02<00:00, 84.12it/s]
100%|██████████| 235/235 [00:02<00:00, 95.42it/s] 
100%|██████████| 235/235 [00:02<00:00, 99.21it/s] 
100%|██████████| 235/235 [00:02<00:00, 94.94it/s] 
100%|██████████| 235/235 [00:02<00:00, 85.14it/s] 
100%|██████████| 235/235 [00:02<00:00, 85.64it/s] 
100%|██████████| 235/235 [00:02<00:00, 100.19it/s]
100%|██████████| 235/235 [00:02<00:00, 101.65it/s]
100%|██████████| 235/235 [00:02<00:00

--- 52.58237981796265 seconds ---


100%|██████████| 157/157 [00:01<00:00, 115.95it/s]


Accuracy: 0.5592


100%|██████████| 235/235 [00:02<00:00, 82.89it/s]
100%|██████████| 235/235 [00:02<00:00, 89.00it/s] 
100%|██████████| 235/235 [00:02<00:00, 94.40it/s]
100%|██████████| 235/235 [00:02<00:00, 89.01it/s]
100%|██████████| 235/235 [00:02<00:00, 90.74it/s]
100%|██████████| 235/235 [00:02<00:00, 93.86it/s] 
100%|██████████| 235/235 [00:02<00:00, 102.42it/s]
100%|██████████| 235/235 [00:02<00:00, 97.48it/s] 
100%|██████████| 235/235 [00:02<00:00, 91.14it/s] 
100%|██████████| 235/235 [00:02<00:00, 83.73it/s]
100%|██████████| 235/235 [00:02<00:00, 90.05it/s]
100%|██████████| 235/235 [00:02<00:00, 98.73it/s] 
100%|██████████| 235/235 [00:02<00:00, 93.01it/s] 
100%|██████████| 235/235 [00:02<00:00, 89.21it/s]
100%|██████████| 235/235 [00:02<00:00, 87.83it/s] 
100%|██████████| 235/235 [00:02<00:00, 97.80it/s] 
100%|██████████| 235/235 [00:02<00:00, 103.49it/s]
100%|██████████| 235/235 [00:02<00:00, 85.42it/s]
100%|██████████| 235/235 [00:02<00:00, 82.45it/s]
100%|██████████| 235/235 [00:02<00:00, 9

--- 51.550086975097656 seconds ---


100%|██████████| 157/157 [00:01<00:00, 122.83it/s]


Accuracy: 0.5867


100%|██████████| 235/235 [00:02<00:00, 100.53it/s]
100%|██████████| 235/235 [00:02<00:00, 88.19it/s]
100%|██████████| 235/235 [00:02<00:00, 84.81it/s]
100%|██████████| 235/235 [00:02<00:00, 92.11it/s] 
100%|██████████| 235/235 [00:02<00:00, 100.07it/s]
100%|██████████| 235/235 [00:02<00:00, 100.96it/s]
100%|██████████| 235/235 [00:02<00:00, 86.25it/s]
100%|██████████| 235/235 [00:02<00:00, 81.32it/s]
100%|██████████| 235/235 [00:02<00:00, 93.83it/s] 
100%|██████████| 235/235 [00:02<00:00, 97.21it/s] 
100%|██████████| 235/235 [00:02<00:00, 92.76it/s] 
100%|██████████| 235/235 [00:02<00:00, 85.40it/s]
100%|██████████| 235/235 [00:02<00:00, 85.53it/s] 
100%|██████████| 235/235 [00:02<00:00, 96.30it/s] 
100%|██████████| 235/235 [00:02<00:00, 86.30it/s] 
100%|██████████| 235/235 [00:02<00:00, 82.47it/s]
100%|██████████| 235/235 [00:02<00:00, 79.42it/s]
100%|██████████| 235/235 [00:02<00:00, 89.21it/s] 
100%|██████████| 235/235 [00:02<00:00, 98.06it/s] 
100%|██████████| 235/235 [00:02<00:00,

--- 52.40761351585388 seconds ---


100%|██████████| 157/157 [00:01<00:00, 114.33it/s]

Accuracy: 0.5793





In [23]:
# Per-run wall-clock seconds on the first line, test accuracies on the second.
print(time_per_run, acc_per_run, sep="\n")

[50.73576307296753, 52.57818102836609, 52.58237981796265, 51.550086975097656, 52.40761351585388]
[0.5683, 0.5755, 0.5592, 0.5867, 0.5793]


## Random Subset Meta

In [1]:
# Load CIFAR10 datasets
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test)

# Hold the meta-validation images out of the test set and rebuild the validation
# loader here. Otherwise the meta loop would draw validation batches from a split made
# in an earlier cell while being evaluated on the *full* test set, i.e. partly on
# images it was tuned on; and on a fresh kernel subset_dataloader would not exist.
n_samples = len(test_dataset)
n_test = int(n_samples * split_ratio)
n_val = n_samples - n_test
test_dataset, val_dataset = random_split(
    test_dataset, [n_test, n_val], generator=torch.Generator().manual_seed(seed)
)

# Uniformly random subset holding `subset_fraction` of the training set; the seeded
# generator selects the same images on every run of the notebook.
subset_size = int(subset_fraction * len(train_dataset))
subset_indices = torch.randperm(
    len(train_dataset), generator=torch.Generator().manual_seed(seed)
)[:subset_size]
subset_dataset = torch.utils.data.Subset(train_dataset, subset_indices)


# Create dataloaders
train_dataloader = DataLoader(subset_dataset, batch_size=64, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)
# One fresh random batch of 64 validation images per `next(iter(subset_dataloader))`.
subset_sampler = RandomSubsetSampler(val_dataset, 64)
subset_dataloader = DataLoader(val_dataset, batch_size=len(subset_sampler), sampler=subset_sampler)

NameError: name 'datasets' is not defined

In [None]:
# Learning-to-reweight (Ren et al., 2018) on the random subset with LeNet.
seed = 42
torch.manual_seed(seed)

# Fresh result lists: without this, results are appended to those of the previous
# experiment and the reported runs mix two methods.
time_per_run = []
acc_per_run = []

# Step size of the *virtual* SGD step used to score training examples; the weights
# are normalised to sum to 1, so any positive value yields the same weights.
meta_lr = 1e-3

for i in range(num_runs):
    # Define the Model
    model = LeNet()
    model = model.to(device)

    # Define optimizer and loss function
    optimizer = torch.optim.Adam(model.parameters())
    loss_fn = nn.CrossEntropyLoss()
    loss_fn_meta = nn.CrossEntropyLoss(reduction='none')  # per-example losses

    # Train the model
    model.train()
    start_time = time.time()
    for epoch in tqdm(range(epochs)):
        for images, labels in train_dataloader:
            images = images.to(device)
            labels = labels.to(device)

            # --- Meta step: score every example of the batch -----------------
            # Per-example weights eps (all zero) enter the training loss; eps is a
            # leaf so the post-step validation loss can be differentiated w.r.t. it.
            params = dict(model.named_parameters())
            cost = loss_fn_meta(model(images), labels)
            eps = torch.zeros_like(cost, requires_grad=True)
            l_f_meta = torch.sum(cost * eps)

            # Differentiable virtual step theta' = theta - meta_lr * grad(l_f_meta);
            # create_graph=True keeps theta' a function of eps (an optimizer.step()
            # on a copied network would be in-place and non-differentiable).
            grads = torch.autograd.grad(l_f_meta, tuple(params.values()), create_graph=True)
            fast_params = {
                name: p - meta_lr * g for (name, p), g in zip(params.items(), grads)
            }

            # Validation loss of the virtually updated network.
            val_images, val_labels = next(iter(subset_dataloader))
            val_images = val_images.to(device)
            val_labels = val_labels.to(device)
            model.eval()
            y_g_hat = functional_call(model, fast_params, (val_images,))
            model.train()
            l_g_meta = loss_fn(y_g_hat, val_labels)
            grad_eps = torch.autograd.grad(l_g_meta, eps)[0]

            # Keep examples whose up-weighting *lowers* the validation loss
            # (negative gradient), then normalise the weights to sum to 1.
            w_tilde = torch.clamp(-grad_eps, min=0)
            norm_c = torch.sum(w_tilde)
            if norm_c != 0:
                w = w_tilde / norm_c
            else:
                w = w_tilde

            # --- Real update with the per-example weights --------------------
            outputs = model(images)
            loss = torch.sum(loss_fn_meta(outputs, labels) * w)

            # Backward pass and update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    time_taken = time.time() - start_time
    time_per_run.append(time_taken)
    print("--- %s seconds ---" % (time_taken))

    # Evaluate on test set (eval mode so normalisation/dropout layers, if any,
    # behave deterministically).
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(test_dataloader):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    acc_per_run.append(accuracy)
    print(f"Accuracy: {accuracy:.4f}")