In [2]:
from torchvision.models import resnet18
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from tqdm import tqdm
import torchvision.transforms as transforms
from torchvision import datasets, transforms
from tqdm import tqdm 
import time
from torch.utils.data import random_split, Dataset, DataLoader
from torchvision.models.resnet import ResNet18_Weights
import pickle
import random
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
import statistics
import matplotlib.pyplot as plt


# Global seed for reproducibility.
seed = 42
# Seed Python's built-in RNG as well: `random` is imported above, and
# torch.manual_seed() only seeds torch's own generators.
random.seed(seed)
# Kept as the last expression so the cell still displays the torch Generator.
torch.manual_seed(seed)

<torch._C.Generator at 0x7f63e915aa10>

In [3]:
import sys

# Add the parent directory to the module search path (presumably so the
# project-local `models` package imported in a later cell resolves -- confirm).
# The membership check keeps the cell idempotent: re-running it does not
# stack duplicate '..' entries on sys.path.
if '..' not in sys.path:
    sys.path.append('..')

In [4]:
# Pick the compute device: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    # change the available gpu number
    # GPU 3 is the preferred card; on hosts exposing fewer than four CUDA
    # devices "cuda:3" is an invalid ordinal, so fall back to the first GPU.
    device = "cuda:3" if torch.cuda.device_count() > 3 else "cuda:0"
else:
    device = "cpu"

In [5]:
# ---- Experiment configuration ----

# Filesystem locations (relative paths).
data_dir = "../data"
milo_sub_base_dir = "../data/milo-data-gen/cifar10-dino-cls"

# Architecture to train. The training cell recognises "LeNet", "resnet18"
# and "resnet101".
model_name = "resnet18"

# Subset-selection settings. NOTE(review): `submod_func` and `subset_fraction`
# are reassigned by later cells before the subsets are loaded.
submod_func = "facility-location"
subset_fraction = 0.1

# Training schedule.
epochs = 40
num_runs = 1

In [6]:
from models.LeNet_model import LeNet
from models.resent_models import get_resent101_model, get_resent18_model
from models.utils import SubDataset

In [7]:
# Preprocessing pipelines. Both end with ToTensor + the same per-channel
# normalisation; only the training pipeline adds augmentation.

# Evaluation: no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2023, 0.1994, 0.2010),
    ),
])

# Training: random 32x32 crop (after 4-pixel padding) and random horizontal flip.
transform_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2023, 0.1994, 0.2010),
    ),
])

In [8]:
# CIFAR-10 train/test splits, downloaded into `data_dir` when missing.
trainset = datasets.CIFAR10(
    root=data_dir,
    train=True,
    download=True,
    transform=transform_train,
)
testset = datasets.CIFAR10(
    root=data_dir,
    train=False,
    download=True,
    transform=transform_test,
)

# Full-dataset loaders: the training loader reshuffles, the test loader keeps order.
train_dataloader = DataLoader(
    trainset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)
test_dataloader = DataLoader(
    testset,
    batch_size=64,
    shuffle=False,
    num_workers=2,
)

Files already downloaded and verified
Files already downloaded and verified


In [9]:
# Settings that pick which pre-computed subsets get loaded below; both values
# are interpolated into the pickle path. This replaces the `submod_func`
# value assigned in the configuration cell.
submod_func, metric = "disparity-sum", "cosine"

In [10]:
# Load one pickle per class; class_data[c] holds whatever class_<c>.pkl contains
# (the next cell indexes it as a sequence of per-set index collections).
num_classes = 10
# This replaces the `subset_fraction` value assigned in the configuration cell.
subset_fraction = 0.3

class_data = []
for class_idx in range(num_classes):
    pickle_path = f"{milo_sub_base_dir}/SGE-{metric}/{submod_func}/class-data-{subset_fraction}/class_{class_idx}.pkl"
    # NOTE(review): pickle.load can execute arbitrary code -- only point this
    # at files produced by a trusted pipeline.
    with open(pickle_path, "rb") as handle:
        class_data.append(pickle.load(handle))

In [11]:
# Build `data`: entry k concatenates the k-th subset of every class, in class
# order. The number of subsets is taken from class 0.
num_sets = len(class_data[0])
data = [
    [
        sample
        for class_idx in range(num_classes)
        for sample in class_data[class_idx][set_idx]
    ]
    for set_idx in range(num_sets)
]

In [12]:
# Train `model_name` on a sequence of pre-selected subsets and record the
# test accuracy after every epoch in `acc_list`.
torch.manual_seed(42)
num_runs = 1
R = 1  # a new subset is taken from `data` every R epochs

# Define Model
if model_name == "LeNet":
    model = LeNet()
elif model_name == "resnet18":
    model = get_resent18_model()
elif model_name == "resnet101":
    model = get_resent101_model()
else:
    # Fail with a clear message instead of a NameError on `model` below.
    raise ValueError(f"Unsupported model_name: {model_name!r}")

model = model.to(device)
acc_list = []

# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

for epoch in tqdm(range(epochs)):

    # Every R epochs, switch to the next subset and rebuild its loader.
    # `data` must hold at least ceil(epochs / R) subsets, otherwise the
    # lookup below raises IndexError.
    if epoch % R == 0:
        sub_dataset = SubDataset(indices=data[epoch // R], dataset=trainset)
        subset_train_dataloader = DataLoader(sub_dataset, batch_size=64, shuffle=True)

    # Re-enter training mode at the start of each epoch, because the
    # evaluation pass at the end of the previous epoch switched to eval mode.
    model.train()
    for images, labels in subset_train_dataloader:

        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = loss_fn(outputs, labels)

        # Backward pass and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate on the test set in eval mode, so layers that behave differently
    # during training (e.g. BatchNorm, Dropout) run in inference mode and test
    # batches do not update any running statistics.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_dataloader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    acc_list.append(accuracy)

  2%|▎         | 1/40 [00:09<06:15,  9.63s/it]


KeyboardInterrupt: 

In [19]:
# Hard-coded accuracy values: 40 entries, the same count as `epochs` in the
# configuration cell.
# NOTE(review): presumably pasted from an earlier completed run, since the
# training cell's captured output shows a KeyboardInterrupt -- confirm the
# provenance. This overwrites whatever `acc_list` the training cell produced.
acc_list = [
    0.4304, 0.463, 0.5157, 0.5679, 0.5817, 0.6092, 0.6314, 0.6607, 0.649, 0.6398,
    0.6674, 0.6698, 0.6851, 0.7083, 0.7057, 0.6902, 0.7114, 0.7116, 0.7289, 0.7323,
    0.7311, 0.7321, 0.7366, 0.7453, 0.7436, 0.7436, 0.7484, 0.7451, 0.746, 0.7415,
    0.7463, 0.751, 0.755, 0.746, 0.7578, 0.7574, 0.7626, 0.7712, 0.7661, 0.7582,
]