In [2]:
from torchvision.models import resnet18
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from tqdm import tqdm
import torchvision.transforms as transforms
from torchvision import datasets, transforms
from tqdm import tqdm 
import time
from torch.utils.data import random_split, Dataset, DataLoader
from torchvision.models.resnet import ResNet18_Weights
import pickle
import random
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
import statistics
import matplotlib.pyplot as plt


seed = 42
torch.manual_seed(seed)

<torch._C.Generator at 0x7f793c2392b0>

In [3]:
import sys
sys.path.append('..')

In [4]:
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda:3" # change the available gpu number
else:
    device = "cpu"

In [5]:
subset_fraction = 0.1
num_runs = 1
epochs = 40
# model_name = "LeNet"
model_name = "resnet18"
submod_func = "facility-location"
data_dir = "../data"
milo_sub_base_dir = "../data/milo-data-gen/cifar10-dino-cls"

In [6]:
from models.LeNet_model import LeNet
from models.resent_models import get_resent101_model, get_resent18_model
from models.utils import SubDataset, RandomSubsetSampler

In [7]:
# Define data transforms
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

In [21]:
# Load CIFAR10 datasets
trainset = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform_test)

split_ratio = 0.9

n_samples = len(test_dataset)
n_test = int(n_samples * split_ratio)
n_val = n_samples - n_test
testset, valset = random_split(test_dataset, [n_test, n_val])


# Create dataloaders
train_dataloader = DataLoader(trainset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(testset, batch_size=64, shuffle=False)
val_dataloader = DataLoader(valset, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [16]:
subset_sampler = RandomSubsetSampler(valset, 64)
meta_dataloader = DataLoader(valset, sampler=subset_sampler)

In [17]:
submod_func = "facility-location"
metric = "cosine"

In [18]:
num_classes = 10
class_data = []
subset_fraction = 0.3
for i in range(num_classes):
    with open(f"{milo_sub_base_dir}/SGE-{metric}/{submod_func}/class-data-{subset_fraction}/class_{i}.pkl", "rb") as f:
        S = pickle.load(f)
        class_data.append(S)

In [19]:
num_sets = len(class_data[0])
data = []
for i in range(num_sets):
    S = []
    for j in range(num_classes):
        S.extend(class_data[j][i])
    data.append(S)

In [22]:
# Define Model
if model_name=="LeNet":
    model = LeNet()
elif model_name=="resnet18":
    model = get_resent18_model()
elif  model_name=="resnet101":
    model = get_resent101_model()

model = model.to(device)
acc_list = []
R=1

# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()
loss_fn_meta = nn.CrossEntropyLoss(reduction='none')

# Train the model
model.train()
start_time = time.time()
for epoch in tqdm(range(epochs)):

    if epoch%R==0:
        sub_dataset = SubDataset(indices=data[epoch//R], dataset=trainset)
        subset_train_dataloader = DataLoader(sub_dataset, batch_size=64, shuffle=True)
    
    for images, labels in subset_train_dataloader:
        
        images = images.to(device)
        labels = labels.to(device)

        if model_name=="LeNet":
            meta_net = LeNet()
        elif model_name=="resnet18":
            meta_net = get_resent18_model()
        elif  model_name=="resnet101":
            meta_net = get_resent101_model()
        
        meta_net.load_state_dict(model.state_dict())
        meta_net = meta_net.to(device)
        optimizer_meta = torch.optim.Adam(meta_net.parameters())

        meta_net.train()
        
        y_f_hat = meta_net(images)
        cost = loss_fn_meta(y_f_hat, labels)
        eps = torch.zeros(cost.size(), requires_grad=True).to(device)
        l_f_meta = torch.sum(cost*eps) #f i,e

        # meta_net.zero_grad()
        optimizer_meta.zero_grad()
        eps.retain_grad()
        l_f_meta.backward()
        optimizer_meta.step()

        meta_net.eval()
        
        val_images, val_labels = next(iter(meta_dataloader))
        val_images = val_images.to(device)
        val_labels = val_labels.to(device)

        y_g_hat = meta_net(val_images)
        l_g_meta = loss_fn(y_g_hat, val_labels)

        with torch.no_grad():
            optimizer_meta.zero_grad()
            l_g_meta.backward()
            grad_eps = eps.grad
        
        w_tilde = torch.clamp(grad_eps,min=0)
        norm_c = torch.sum(w_tilde)

        if norm_c != 0:
            w = w_tilde / norm_c
        else:
            w = w_tilde
        
        # Forward Pass
        outputs = model(images)
        loss = loss_fn_meta(outputs, labels)
        loss = torch.sum(loss*w)
        
        # Backward pass and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


    # Evaluate on test set
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_dataloader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    acc_list.append(accuracy) 


100%|██████████| 141/141 [00:01<00:00, 77.47it/s]
100%|██████████| 141/141 [00:01<00:00, 78.86it/s]
100%|██████████| 141/141 [00:01<00:00, 80.25it/s]
100%|██████████| 141/141 [00:01<00:00, 81.32it/s]
100%|██████████| 141/141 [00:01<00:00, 81.53it/s]
100%|██████████| 141/141 [00:01<00:00, 80.96it/s]
100%|██████████| 141/141 [00:01<00:00, 80.44it/s]
100%|██████████| 141/141 [00:01<00:00, 80.26it/s]
100%|██████████| 141/141 [00:01<00:00, 74.22it/s]
100%|██████████| 141/141 [00:01<00:00, 81.05it/s]
100%|██████████| 141/141 [00:01<00:00, 80.62it/s]
100%|██████████| 141/141 [00:01<00:00, 80.83it/s]
100%|██████████| 141/141 [00:01<00:00, 80.73it/s]
100%|██████████| 141/141 [00:01<00:00, 81.14it/s]
100%|██████████| 141/141 [00:01<00:00, 80.70it/s]
100%|██████████| 141/141 [00:01<00:00, 81.42it/s]
100%|██████████| 141/141 [00:01<00:00, 75.08it/s]
100%|██████████| 141/141 [00:01<00:00, 80.88it/s]
100%|██████████| 141/141 [00:01<00:00, 79.70it/s]
100%|██████████| 141/141 [00:01<00:00, 80.25it/s]


### meta-milo

In [23]:
print(acc_list)

[0.3661111111111111, 0.44077777777777777, 0.49955555555555553, 0.5187777777777778, 0.5256666666666666, 0.5448888888888889, 0.5647777777777778, 0.5698888888888889, 0.5917777777777777, 0.6052222222222222, 0.6376666666666667, 0.6241111111111111, 0.6347777777777778, 0.6568888888888889, 0.6546666666666666, 0.661, 0.6761111111111111, 0.6602222222222223, 0.6622222222222223, 0.6893333333333334, 0.6777777777777778, 0.686, 0.686, 0.6961111111111111, 0.6658888888888889, 0.7047777777777777, 0.7133333333333334, 0.7023333333333334, 0.6986666666666667, 0.7231111111111111, 0.7243333333333334, 0.7223333333333334, 0.7211111111111111, 0.732, 0.7294444444444445, 0.7263333333333334, 0.733, 0.7331111111111112, 0.734, 0.7377777777777778]
