In [41]:
from torchvision.models import resnet18
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from tqdm import tqdm
import torchvision.transforms as transforms
from torchvision import datasets, transforms
from tqdm import tqdm 
import time
from torch.utils.data import random_split, Dataset, DataLoader
from torchvision.models.resnet import ResNet18_Weights
import torchviz
import pickle
import random

In [2]:
device = "mps" if torch.backends.mps.is_available() else "cpu"
device = "cuda" if torch.cuda.is_available() else "cpu"

### Load Resent Model 

In [3]:
def get_cifar10_model():
    model = torchvision.models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)  # Use 'weights' for pretrained models
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, 10)
    return model

In [4]:
model = get_cifar10_model()
print(type(model))

<class 'torchvision.models.resnet.ResNet'>


In [8]:
# Freeze pre-trained layers
for param in model.parameters():
    param.requires_grad = False

# Unfreeze some layers for fine-tuning
for param in model.layer4.parameters():
    param.requires_grad = True

In [5]:
# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

In [5]:
# Define data transforms
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

In [7]:
# Load CIFAR10 datasets
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test)

# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [03:33<00:00, 798661.74it/s] 


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


## LeNet Model Definition

In [3]:
class LeNet(nn.Module):    
    def __init__(self):
        super(LeNet, self).__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )

        self.classifier = nn.Sequential(
            nn.Linear(400,120),  #in_features = 16 x5x5 
            nn.ReLU(),
            nn.Linear(120,84),
            nn.ReLU(),
            nn.Linear(84,10),
            nn.Softmax(dim=1)

        )
        
    def forward(self,x): 
        a1=self.feature_extractor(x)
        # print(a1.shape)
        a1 = torch.flatten(a1,1)
        a2=self.classifier(a1)
        return a2

# Baseline Model Training

In [64]:
model = LeNet()
model = model.to(device)

### Basic Train Loop

In [68]:
# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

# Train the model
epochs = 20
model.train()
start_time = time.time()
for epoch in tqdm(range(epochs)):
    # Train loop
    for images, labels in tqdm(train_dataloader):

        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        
        # Backward pass and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate on test set
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(test_dataloader):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    print(f"Epoch: [{epoch+1}/{epochs}], Accuracy: {accuracy:.4f}")

print("--- %s seconds ---" % (time.time() - start_time))

100%|██████████| 782/782 [00:19<00:00, 39.53it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.51it/s]
  5%|▌         | 1/20 [00:32<10:14, 32.34s/it]

Epoch: [1/20], Accuracy: 0.3640


100%|██████████| 782/782 [00:18<00:00, 41.80it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.56it/s]
 10%|█         | 2/20 [01:03<09:30, 31.68s/it]

Epoch: [2/20], Accuracy: 0.3962


100%|██████████| 782/782 [00:18<00:00, 42.51it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.45it/s]
 15%|█▌        | 3/20 [01:34<08:53, 31.37s/it]

Epoch: [3/20], Accuracy: 0.4343


100%|██████████| 782/782 [00:18<00:00, 42.18it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.58it/s]
 20%|██        | 4/20 [02:05<08:19, 31.24s/it]

Epoch: [4/20], Accuracy: 0.4395


100%|██████████| 782/782 [00:18<00:00, 42.73it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.53it/s]
 25%|██▌       | 5/20 [02:36<07:46, 31.09s/it]

Epoch: [5/20], Accuracy: 0.4738


100%|██████████| 782/782 [00:18<00:00, 42.62it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.51it/s]
 30%|███       | 6/20 [03:07<07:14, 31.03s/it]

Epoch: [6/20], Accuracy: 0.4839


100%|██████████| 782/782 [00:18<00:00, 41.32it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.42it/s]
 35%|███▌      | 7/20 [03:38<06:45, 31.21s/it]

Epoch: [7/20], Accuracy: 0.4826


100%|██████████| 782/782 [00:18<00:00, 42.34it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.52it/s]
 40%|████      | 8/20 [04:09<06:13, 31.15s/it]

Epoch: [8/20], Accuracy: 0.4939


100%|██████████| 782/782 [00:18<00:00, 42.67it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.55it/s]
 45%|████▌     | 9/20 [04:40<05:41, 31.05s/it]

Epoch: [9/20], Accuracy: 0.5017


100%|██████████| 782/782 [00:18<00:00, 42.82it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.56it/s]
 50%|█████     | 10/20 [05:11<05:09, 30.97s/it]

Epoch: [10/20], Accuracy: 0.5002


100%|██████████| 782/782 [00:18<00:00, 42.77it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.61it/s]
 55%|█████▌    | 11/20 [05:42<04:38, 30.90s/it]

Epoch: [11/20], Accuracy: 0.5049


100%|██████████| 782/782 [00:18<00:00, 42.81it/s] 
100%|██████████| 157/157 [01:22<00:00,  1.89it/s]
 60%|██████    | 12/20 [07:23<06:58, 52.29s/it]

Epoch: [12/20], Accuracy: 0.5297


100%|██████████| 782/782 [00:24<00:00, 32.26it/s] 
100%|██████████| 157/157 [01:56<00:00,  1.35it/s]
 65%|██████▌   | 13/20 [09:44<09:13, 79.11s/it]

Epoch: [13/20], Accuracy: 0.5336


100%|██████████| 782/782 [14:00<00:00,  1.07s/it]
100%|██████████| 157/157 [00:13<00:00, 11.67it/s]
 70%|███████   | 14/20 [23:58<31:19, 313.19s/it]

Epoch: [14/20], Accuracy: 0.5354


100%|██████████| 782/782 [00:18<00:00, 42.32it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.37it/s] 
 75%|███████▌  | 15/20 [24:29<19:00, 228.18s/it]

Epoch: [15/20], Accuracy: 0.5335


100%|██████████| 782/782 [00:18<00:00, 41.98it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.57it/s]
 80%|████████  | 16/20 [25:00<11:15, 168.87s/it]

Epoch: [16/20], Accuracy: 0.5439


100%|██████████| 782/782 [00:18<00:00, 42.59it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.50it/s]
 85%|████████▌ | 17/20 [25:31<06:22, 127.39s/it]

Epoch: [17/20], Accuracy: 0.5424


100%|██████████| 782/782 [00:18<00:00, 42.21it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.42it/s] 
 90%|█████████ | 18/20 [26:02<03:16, 98.48s/it] 

Epoch: [18/20], Accuracy: 0.5410


100%|██████████| 782/782 [00:18<00:00, 42.66it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.46it/s]
 95%|█████████▌| 19/20 [26:33<01:18, 78.20s/it]

Epoch: [19/20], Accuracy: 0.5327


100%|██████████| 782/782 [00:18<00:00, 41.43it/s] 
100%|██████████| 157/157 [00:12<00:00, 12.45it/s]
100%|██████████| 20/20 [27:05<00:00, 81.26s/it]

Epoch: [20/20], Accuracy: 0.5475
--- 1625.2944419384003 seconds ---





# Reweight Model Training

In [34]:
# model = get_cifar10_model()
model = LeNet()
model = model.to(device)

In [6]:
# Load CIFAR10 datasets
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test)

split_ratio = 0.95 # 80% for test, 20% for validation
n_samples = len(test_dataset)
n_test = int(n_samples * split_ratio)
n_val = n_samples - n_test
test_dataset, val_dataset = random_split(test_dataset, [n_test, n_val])


# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


### Random Sampler for Sampling Validation Data

In [16]:
class RandomSubsetSampler(torch.utils.data.Sampler):
    def __init__(self, dataset, subset_size):
        self.dataset = dataset
        self.subset_size = subset_size

    def __iter__(self):
        indices = random.sample(range(len(self.dataset)), self.subset_size)
        return iter(indices)

    def __len__(self):
        return self.subset_size

subset_sampler = RandomSubsetSampler(val_dataset, 64)
subset_dataloader = DataLoader(val_dataset, sampler=subset_sampler)

500


### Meta Baseline Trainloop

In [39]:
# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()
loss_fn_meta = nn.CrossEntropyLoss(reduction='none')

# Train the model
epochs = 20
lr = 0.001
model.train()
start_time = time.time()
for epoch in tqdm(range(epochs)):
    # Train loop
    for images, labels in tqdm(train_dataloader):
        
        images = images.to(device)
        labels = labels.to(device)

        # meta_net = get_cifar10_model()
        meta_net = LeNet()
        meta_net.load_state_dict(model.state_dict())

        meta_net = meta_net.to(device)

        optimizer_meta = torch.optim.Adam(meta_net.parameters())

        meta_net.train()
        
        y_f_hat = meta_net(images)
        cost = loss_fn_meta(y_f_hat, labels)
        eps = torch.zeros(cost.size(), requires_grad=True).to(device)
        l_f_meta = torch.sum(cost*eps)

        # meta_net.zero_grad()
        optimizer_meta.zero_grad()
        eps.retain_grad()
        l_f_meta.backward()
        optimizer_meta.step()

        meta_net.eval()

        # grads = torch.autograd.grad(l_f_meta, (meta_net.parameters()), create_graph=True)
        # meta_net.update_params(lr, source_params=grads)
        
        val_images, val_labels = next(iter(subset_dataloader))
        # val_images, val_labels = next(iter(val_dataloader))
        val_images = val_images.to(device)
        val_labels = val_labels.to(device)

        y_g_hat = meta_net(val_images)
        l_g_meta = loss_fn(y_g_hat, val_labels)

        # grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True)[0]
        # grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True, allow_unused=True)[0]
        # print(grad_eps)

        with torch.no_grad():
            optimizer_meta.zero_grad()
            l_g_meta.backward()
            grad_eps = eps.grad
        
        # print(grad_eps)
        w_tilde = torch.clamp(grad_eps,min=0)
        # w_tilde = torch.clamp(-grad_eps,min=0)
        norm_c = torch.sum(w_tilde)

        if norm_c != 0:
            w = w_tilde / norm_c
        else:
            w = w_tilde
        
        # print(w)
        # break
        # Forward Pass
        outputs = model(images)
        loss = loss_fn_meta(outputs, labels)
        loss = torch.sum(loss*w)
        
        # Backward pass and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate on test set
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(test_dataloader):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    print(f"Epoch: [{epoch+1}/{epochs}], Accuracy: {accuracy:.4f}")

print("--- %s seconds ---" % (time.time() - start_time))

100%|██████████| 782/782 [00:38<00:00, 20.54it/s]
100%|██████████| 149/149 [00:12<00:00, 11.65it/s] 
  5%|▌         | 1/20 [00:50<16:06, 50.86s/it]

Epoch: [1/20], Accuracy: 0.3925


100%|██████████| 782/782 [00:37<00:00, 20.92it/s]
100%|██████████| 149/149 [00:12<00:00, 11.79it/s] 
 10%|█         | 2/20 [01:40<15:06, 50.37s/it]

Epoch: [2/20], Accuracy: 0.4705


100%|██████████| 782/782 [00:37<00:00, 20.74it/s]
100%|██████████| 149/149 [00:12<00:00, 11.61it/s] 
 15%|█▌        | 3/20 [02:31<14:17, 50.45s/it]

Epoch: [3/20], Accuracy: 0.4973


100%|██████████| 782/782 [00:38<00:00, 20.46it/s]
100%|██████████| 149/149 [00:12<00:00, 11.67it/s] 
 20%|██        | 4/20 [03:22<13:30, 50.67s/it]

Epoch: [4/20], Accuracy: 0.5048


100%|██████████| 782/782 [00:38<00:00, 20.40it/s]
100%|██████████| 149/149 [00:12<00:00, 11.72it/s] 
 25%|██▌       | 5/20 [04:13<12:42, 50.80s/it]

Epoch: [5/20], Accuracy: 0.5098


100%|██████████| 782/782 [00:37<00:00, 20.67it/s]
100%|██████████| 149/149 [00:12<00:00, 11.74it/s] 
 30%|███       | 6/20 [05:04<11:49, 50.71s/it]

Epoch: [6/20], Accuracy: 0.5087


100%|██████████| 782/782 [00:37<00:00, 20.87it/s]
100%|██████████| 149/149 [00:12<00:00, 11.64it/s] 
 35%|███▌      | 7/20 [05:54<10:57, 50.57s/it]

Epoch: [7/20], Accuracy: 0.5324


100%|██████████| 782/782 [00:37<00:00, 20.86it/s]
100%|██████████| 149/149 [01:07<00:00,  2.20it/s] 
 40%|████      | 8/20 [07:39<13:35, 67.97s/it]

Epoch: [8/20], Accuracy: 0.5362


100%|██████████| 782/782 [00:38<00:00, 20.45it/s]
100%|██████████| 149/149 [00:12<00:00, 11.63it/s]
 45%|████▌     | 9/20 [08:30<11:29, 62.69s/it]

Epoch: [9/20], Accuracy: 0.5518


100%|██████████| 782/782 [00:37<00:00, 20.60it/s]
100%|██████████| 149/149 [00:12<00:00, 11.52it/s]
 50%|█████     | 10/20 [09:21<09:50, 59.05s/it]

Epoch: [10/20], Accuracy: 0.5537


100%|██████████| 782/782 [00:38<00:00, 20.37it/s]
100%|██████████| 149/149 [00:12<00:00, 11.64it/s] 
 55%|█████▌    | 11/20 [10:12<08:29, 56.64s/it]

Epoch: [11/20], Accuracy: 0.5647


100%|██████████| 782/782 [00:38<00:00, 20.19it/s]
100%|██████████| 149/149 [00:12<00:00, 11.50it/s]
 60%|██████    | 12/20 [11:04<07:21, 55.14s/it]

Epoch: [12/20], Accuracy: 0.5621


100%|██████████| 782/782 [00:38<00:00, 20.35it/s]
100%|██████████| 149/149 [00:12<00:00, 11.64it/s] 
 65%|██████▌   | 13/20 [11:55<06:17, 53.96s/it]

Epoch: [13/20], Accuracy: 0.5765


100%|██████████| 782/782 [00:38<00:00, 20.47it/s]
100%|██████████| 149/149 [00:12<00:00, 11.53it/s]
 70%|███████   | 14/20 [12:46<05:18, 53.10s/it]

Epoch: [14/20], Accuracy: 0.5733


100%|██████████| 782/782 [00:38<00:00, 20.42it/s]
100%|██████████| 149/149 [00:12<00:00, 11.64it/s] 
 75%|███████▌  | 15/20 [13:37<04:22, 52.50s/it]

Epoch: [15/20], Accuracy: 0.5783


100%|██████████| 782/782 [00:38<00:00, 20.25it/s]
100%|██████████| 149/149 [00:13<00:00, 11.30it/s]
 80%|████████  | 16/20 [14:29<03:29, 52.29s/it]

Epoch: [16/20], Accuracy: 0.5848


100%|██████████| 782/782 [00:38<00:00, 20.25it/s]
100%|██████████| 149/149 [00:12<00:00, 11.51it/s]
 85%|████████▌ | 17/20 [15:21<02:36, 52.07s/it]

Epoch: [17/20], Accuracy: 0.5947


100%|██████████| 782/782 [00:38<00:00, 20.36it/s]
100%|██████████| 149/149 [00:12<00:00, 11.62it/s] 
 90%|█████████ | 18/20 [16:12<01:43, 51.82s/it]

Epoch: [18/20], Accuracy: 0.5984


100%|██████████| 782/782 [00:37<00:00, 20.98it/s]
100%|██████████| 149/149 [00:12<00:00, 11.61it/s] 
 95%|█████████▌| 19/20 [17:02<00:51, 51.32s/it]

Epoch: [19/20], Accuracy: 0.5919


100%|██████████| 782/782 [00:37<00:00, 21.02it/s]
100%|██████████| 149/149 [00:12<00:00, 11.77it/s] 
100%|██████████| 20/20 [17:52<00:00, 53.62s/it]

Epoch: [20/20], Accuracy: 0.6049
--- 1072.4640719890594 seconds ---





# Meta-Milo Setup

### Load Data

In [44]:
num_classes = 10
class_data = []
for i in range(num_classes):
    with open(f"milo-base/class-data/class_{i}.pkl", "rb") as f:
        S = pickle.load(f)
        class_data.append(S)

In [45]:
num_sets = len(class_data[0])
data = []
for i in range(num_sets):
    S = []
    for j in range(num_classes):
        S.extend(class_data[j][i])
    data.append(S)

### Define Dataloader

In [42]:
class SubDataset(Dataset):
    def __init__(self, indices, dataset):
        self.indices = indices
        self.dataset = dataset

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        data_point = self.dataset[index]
        return data_point

### Model Training loop

In [46]:
# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()
loss_fn_meta = nn.CrossEntropyLoss(reduction='none')

# Train the model
epochs = 20
lr = 0.001
model.train()
start_time = time.time()
for epoch in tqdm(range(epochs)):
    # Train loop
    sub_dataset = SubDataset(indices=data[epoch], dataset=train_dataset)
    train_dataloader = DataLoader(sub_dataset, batch_size=64, shuffle=True, num_workers=2)
    
    for images, labels in tqdm(train_dataloader):
        
        images = images.to(device)
        labels = labels.to(device)

        # meta_net = get_cifar10_model()
        meta_net = LeNet()
        meta_net.load_state_dict(model.state_dict())

        meta_net = meta_net.to(device)

        optimizer_meta = torch.optim.Adam(meta_net.parameters())

        meta_net.train()
        
        y_f_hat = meta_net(images)
        cost = loss_fn_meta(y_f_hat, labels)
        eps = torch.zeros(cost.size(), requires_grad=True).to(device)
        l_f_meta = torch.sum(cost*eps)

        # meta_net.zero_grad()
        optimizer_meta.zero_grad()
        eps.retain_grad()
        l_f_meta.backward()
        optimizer_meta.step()

        meta_net.eval()

        # grads = torch.autograd.grad(l_f_meta, (meta_net.parameters()), create_graph=True)
        # meta_net.update_params(lr, source_params=grads)
        
        val_images, val_labels = next(iter(subset_dataloader))
        # val_images, val_labels = next(iter(val_dataloader))
        val_images = val_images.to(device)
        val_labels = val_labels.to(device)

        y_g_hat = meta_net(val_images)
        l_g_meta = loss_fn(y_g_hat, val_labels)

        # grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True)[0]
        # grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True, allow_unused=True)[0]
        # print(grad_eps)

        with torch.no_grad():
            optimizer_meta.zero_grad()
            l_g_meta.backward()
            grad_eps = eps.grad
        
        # print(grad_eps)
        w_tilde = torch.clamp(grad_eps,min=0)
        # w_tilde = torch.clamp(-grad_eps,min=0)
        norm_c = torch.sum(w_tilde)

        if norm_c != 0:
            w = w_tilde / norm_c
        else:
            w = w_tilde
        
        # print(w)
        # break
        # Forward Pass
        outputs = model(images)
        loss = loss_fn_meta(outputs, labels)
        loss = torch.sum(loss*w)
        
        # Backward pass and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate on test set
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(test_dataloader):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    print(f"Epoch: [{epoch+1}/{epochs}], Accuracy: {accuracy:.4f}")

print("--- %s seconds ---" % (time.time() - start_time))

  0%|          | 0/20 [00:00<?, ?it/s]Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/yuktabagdi/miniconda3/envs/feature-env/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/yuktabagdi/miniconda3/envs/feature-env/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'SubDataset' on <module '__main__' (built-in)>
  0%|          | 0/79 [00:21<?, ?it/s]
  0%|          | 0/20 [00:21<?, ?it/s]


KeyboardInterrupt: 