In [1]:
from torchvision.models import resnet18
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from tqdm import tqdm
import torchvision.transforms as transforms
from torchvision import datasets, transforms
from tqdm import tqdm 
import time
from torch.utils.data import random_split, Dataset, DataLoader
from torchvision.models.resnet import ResNet18_Weights
import pickle
import random
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
import statistics
import matplotlib.pyplot as plt


# Seed every RNG the notebook touches so weight init, shuffling and augmentation repeat.
# torch.manual_seed seeds all devices; Python's `random` (imported above) is seeded too.
seed = 42
random.seed(seed)
torch.manual_seed(seed);  # trailing ';' suppresses the Generator repr in the cell output

<torch._C.Generator at 0x7f684bd2d130>

In [2]:
import sys

# Make the parent directory importable so the local `models.*` package resolves.
# The membership check keeps the cell idempotent: re-running it does not keep
# appending duplicate entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

In [3]:
# Pick the compute device: Apple MPS first, then CUDA, otherwise CPU.
CUDA_DEVICE_INDEX = 3  # preferred GPU; change to the GPU you want to use

if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    # Fall back to GPU 0 instead of failing later with "invalid device ordinal"
    # on machines that have fewer than CUDA_DEVICE_INDEX + 1 GPUs.
    cuda_index = CUDA_DEVICE_INDEX if CUDA_DEVICE_INDEX < torch.cuda.device_count() else 0
    device = f"cuda:{cuda_index}"
else:
    device = "cpu"

In [4]:
# Experiment configuration.
data_dir = "../data"                # CIFAR-10 download / cache location
model_name = "resnet18"             # one of "LeNet", "resnet18", "resnet101"
epochs = 40
num_runs = 1
subset_fraction = 0.1
submod_func = "facility-location"

In [9]:
from models.LeNet_model import LeNet
from models.resent_models import get_resent101_model, get_resent18_model
from models.utils import RandomSubsetSampler

In [6]:
# Data transforms: training images are randomly cropped (32x32, pad 4) and
# horizontally flipped; test images are only converted to tensors. Both pipelines
# end with the same per-channel normalization.
# NOTE(review): (0.2023, 0.1994, 0.2010) is a widely copied CIFAR-10 "std" that
# differs from the per-pixel channel std (~0.247, 0.243, 0.261); kept unchanged
# so results stay comparable.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)
normalize = transforms.Normalize(cifar10_mean, cifar10_std)

augmentations = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
transform_train = transforms.Compose(augmentations + [transforms.ToTensor(), normalize])
transform_test = transforms.Compose([transforms.ToTensor(), normalize])

In [13]:
# Load CIFAR-10 and hold out `validation_set_fraction` of the training images
# as a validation split (same seeded split as before).
fullset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform_test)

validation_set_fraction = 0.1
num_fulltrn = len(fullset)
num_val = int(num_fulltrn * validation_set_fraction)
num_trn = num_fulltrn - num_val
trainset, valset_augmented = random_split(
    fullset, [num_trn, num_val], generator=torch.Generator().manual_seed(seed)
)

# `fullset` applies the training augmentation, so the held-out images above would be
# randomly cropped/flipped. Re-index the exact same examples through an un-augmented
# view of the training data so validation (and meta-validation) sees clean images.
from torch.utils.data import Subset

fullset_eval = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=False, transform=transform_test)
valset = Subset(fullset_eval, valset_augmented.indices)

Files already downloaded and verified
Files already downloaded and verified


In [14]:
# Mini-batch loaders: the training split is reshuffled every epoch, while the
# evaluation splits are read in a fixed order.
loader_kwargs = {"batch_size": 64, "num_workers": 2}
train_dataloader = DataLoader(trainset, shuffle=True, **loader_kwargs)
val_dataloader = DataLoader(valset, shuffle=False, **loader_kwargs)
test_dataloader = DataLoader(testset, shuffle=False, **loader_kwargs)

In [15]:
# Source of the clean meta-validation mini-batch drawn in the training loop.
# NOTE(review): assumes RandomSubsetSampler(valset, n) yields n individual indices
# per pass (as a regular Sampler) — confirm in models/utils.py.
meta_val_batch_size = 64
subset_sampler = RandomSubsetSampler(valset, meta_val_batch_size)
# Without an explicit batch_size the DataLoader default of 1 applies, and
# next(iter(subset_dataloader)) would return a single validation image.
subset_dataloader = DataLoader(valset, sampler=subset_sampler, batch_size=meta_val_batch_size)

In [16]:
# Train `model` with per-example weights from "Learning to Reweight Examples for
# Robust Deep Learning" (Ren et al., 2018). For each training mini-batch, the weights
# are the clipped, normalized negative gradient of a clean validation-batch loss
# w.r.t. per-example perturbations `eps`, taken through a one-step look-ahead of
# the parameters.
try:
    from torch.func import functional_call  # PyTorch >= 2.0
except ImportError:  # PyTorch 1.12 / 1.13
    from torch.nn.utils.stateless import functional_call

seed = 42
torch.manual_seed(seed)


def _build_model(name):
    """Instantiate the architecture selected by `name` ("LeNet", "resnet18", "resnet101")."""
    if name == "LeNet":
        return LeNet()
    if name == "resnet18":
        return get_resent18_model()
    if name == "resnet101":
        return get_resent101_model()
    raise ValueError(f"Unknown model_name: {name!r}")


def _l2rw_example_weights(net, images, labels, val_images, val_labels, meta_lr=1e-3):
    """Return non-negative per-example weights (summing to 1 unless all are zero).

    The look-ahead parameters theta_hat = theta - meta_lr * d(sum_i eps_i * l_i)/d theta
    are built functionally with ``create_graph=True``, so the validation loss at
    theta_hat stays differentiable w.r.t. ``eps``. (An in-place ``optimizer.step()``
    cuts that dependency.) Because eps = 0, meta_lr only rescales the gradient, and
    that scale cancels in the normalization.

    ``net``'s parameters are left untouched. Buffers are cloned, so BatchNorm
    running-stat updates made by the meta forwards are discarded. The train/eval mode
    of ``net`` is restored on exit.
    """
    params = {name: p for name, p in net.named_parameters() if p.requires_grad}
    buffers = {name: b.clone() for name, b in net.named_buffers()}
    was_training = net.training
    try:
        net.train()
        logits = functional_call(net, {**params, **buffers}, (images,))
        cost = nn.functional.cross_entropy(logits, labels, reduction="none")
        eps = torch.zeros(cost.size(), device=cost.device, requires_grad=True)
        l_f_meta = torch.sum(cost * eps)
        grads = torch.autograd.grad(
            l_f_meta, tuple(params.values()), create_graph=True, allow_unused=True
        )
        lookahead = {
            name: p if g is None else p - meta_lr * g
            for (name, p), g in zip(params.items(), grads)
        }

        # Score the look-ahead model on the clean batch in eval mode (BatchNorm uses
        # the cloned running stats), as the original meta_net.eval() did.
        net.eval()
        val_logits = functional_call(net, {**lookahead, **buffers}, (val_images,))
        l_g_meta = nn.functional.cross_entropy(val_logits, val_labels)
        grad_eps = torch.autograd.grad(l_g_meta, eps)[0]
    finally:
        net.train(was_training)

    # Up-weight examples whose gradient step would *decrease* the validation loss.
    w_tilde = torch.clamp(-grad_eps, min=0)
    norm_c = torch.sum(w_tilde)
    if norm_c != 0:
        return w_tilde / norm_c
    return w_tilde


def _evaluate(net, loader):
    """Return top-1 accuracy of `net` over `loader`; leaves `net` in eval mode."""
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            predictions = torch.argmax(net(images), dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return correct / total


model = _build_model(model_name).to(device)
acc_list = []

# Define optimizer and loss functions
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()
loss_fn_meta = nn.CrossEntropyLoss(reduction='none')

start_time = time.time()
epochs = 100  # NOTE: overrides the `epochs` value set in the config cell
accuracy = float("nan")
for epoch in tqdm(range(epochs)):
    model.train()  # _evaluate() switched the model to eval mode at the end of the previous epoch
    for images, labels in train_dataloader:
        images = images.to(device)
        labels = labels.to(device)

        val_images, val_labels = next(iter(subset_dataloader))
        val_images = val_images.to(device)
        val_labels = val_labels.to(device)

        w = _l2rw_example_weights(model, images, labels, val_images, val_labels)

        # Weighted training step on the real model.
        outputs = model(images)
        loss = torch.sum(loss_fn_meta(outputs, labels) * w)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate on the test set with BatchNorm/Dropout in inference mode.
    accuracy = _evaluate(model, test_dataloader)
    acc_list.append(accuracy)

print(f"Accuracy: {accuracy:.4f}")

  4%|▍         | 4/100 [07:31<3:00:05, 112.56s/it]

In [None]:
# Per-epoch test accuracy recorded by the training cell.
print(acc_list)

if acc_list:
    best_epoch = max(range(len(acc_list)), key=acc_list.__getitem__)
    print(f"Best test accuracy: {acc_list[best_epoch]:.4f} at epoch {best_epoch + 1}/{len(acc_list)}")

    fig, ax = plt.subplots()
    ax.plot(range(1, len(acc_list) + 1), acc_list, marker="o")
    ax.set(title=f"{model_name}: test accuracy per epoch", xlabel="Epoch", ylabel="Accuracy")
    plt.show()