In [53]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

import numpy as np
from pyhessian import hessian
from tqdm import tqdm
import pandas as pd
import os
import matplotlib.pyplot as plt

from train_mlp import muMLPTab9

device = "cuda"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
def get_cifar(batch_size=128, num_classes=10, MSE=False, on_gpu=False, device=None):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    train_ds = datasets.CIFAR10(root='/tmp', train=True, download=True, transform=transform)

    torch.manual_seed(0)
    np.random.seed(0)
    mask = np.isin(train_ds.targets, np.arange(num_classes))
    indices = np.arange(0, len(train_ds))[mask]

    if on_gpu:
        assert device is not None, "Please provide a device="
        X, y = [], []
        for i in tqdm(range(len(indices))):
            x, y_ = train_ds[i]
            X.append(x)
            y.append(y_)
        X = torch.stack(X)
        y = torch.tensor(y)
        train_ds = torch.utils.data.TensorDataset(X.to(device), y.to(device))
        train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
        print(f"Estimated size of the dataset in MB: {(X.numel() * X.element_size() / 1024 / 1024)+(y.numel() * y.element_size() / 1024 / 1024):.2f}")
    else:
        train_ds = torch.utils.data.Subset(train_ds, indices)
        train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=0)

    return train_dl, train_ds


from torch.utils.data import TensorDataset, DataLoader

def get_cifar(batch_size=128, num_classes=10, MSE=False, on_gpu=False, device=None):
    # Load raw dataset without transforms just to get targets
    raw_ds = datasets.CIFAR10(root='/tmp', train=True, download=True)
    targets = np.array(raw_ds.targets)
    mask = np.isin(targets, np.arange(num_classes))
    indices = np.where(mask)[0]

    # Now reload with transform
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    train_ds = datasets.CIFAR10(root='/tmp', train=True, download=False, transform=transform)

    # Apply mask and extract transformed data
    X, y = [], []
    for i in tqdm(indices):
        x, y_ = train_ds[i]
        X.append(x)
        y.append(y_)
    X = torch.stack(X)
    y = torch.tensor(y)

    # Optional: one-hot encoding
    if MSE:
        y = F.one_hot(y, num_classes=num_classes).float()

    # Move to GPU if needed
    if on_gpu:
        assert device is not None, "Please provide a device="
        X = X.to(device)
        y = y.to(device)

    # Create dataset and loader
    tensor_ds = TensorDataset(X, y)
    train_dl = DataLoader(tensor_ds, batch_size=batch_size, shuffle=True, pin_memory=not on_gpu)

    # Print memory usage
    if on_gpu:
        print(f"Estimated size of the dataset in MB: {(X.numel() * X.element_size() + y.numel() * y.element_size()) / 1024 / 1024:.2f}")

    return train_dl, tensor_ds


In [55]:
seed = 1
epochs = 2

# Tensors loaded on GPU per batch

In [56]:
dl, ds = get_cifar(batch_size=128, num_classes=2, MSE=False, on_gpu=False)
print(len(dl))

torch.manual_seed(seed)
np.random.seed(seed)
print(next(iter(dl))[1].shape)
model = muMLPTab9(128).to(device)

model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1) 
for epoch in range(epochs):
    for i, (X, y) in enumerate(dl):
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(X)
        loss = F.cross_entropy(out, y)
        loss.backward()
        optimizer.step()
    print(loss)

100%|██████████| 10000/10000 [00:01<00:00, 5923.22it/s]


79
torch.Size([128])
tensor(0.1325, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.2127, device='cuda:0', grad_fn=<NllLossBackward0>)


In [57]:
torch.manual_seed(seed)
np.random.seed(seed)
print(next(iter(dl))[1].shape)
model = muMLPTab9(128).to(device)

model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1) 
for epoch in range(epochs):
    for i, (X, y) in enumerate(dl):
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(X)
        loss = F.cross_entropy(out, y)
        loss.backward()
        optimizer.step()
    print(loss)

torch.Size([128])
tensor(0.1325, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.2127, device='cuda:0', grad_fn=<NllLossBackward0>)


# Tensors on GPU

In [58]:
dl, ds = get_cifar(batch_size=128, num_classes=2, MSE=False, on_gpu=True, device=device)
print(len(dl))

torch.manual_seed(seed)
np.random.seed(seed)
print(next(iter(dl))[1].shape)
model = muMLPTab9(128).to(device)

model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1) 
for epoch in range(epochs):
    for i, (X, y) in enumerate(dl):
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(X)
        loss = F.cross_entropy(out, y)
        loss.backward()
        optimizer.step()
    print(loss)

100%|██████████| 10000/10000 [00:01<00:00, 5010.60it/s]


Estimated size of the dataset in MB: 117.26
79
torch.Size([128])
tensor(0.1325, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.2127, device='cuda:0', grad_fn=<NllLossBackward0>)


In [59]:
torch.manual_seed(seed)
np.random.seed(seed)
print(next(iter(dl))[1].shape)
model = muMLPTab9(128).to(device)

model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1) 
for epoch in range(epochs):
    for i, (X, y) in enumerate(dl):
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(X)
        loss = F.cross_entropy(out, y)
        loss.backward()
        optimizer.step()
    print(loss)

torch.Size([128])
tensor(0.1325, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.2127, device='cuda:0', grad_fn=<NllLossBackward0>)


# MSE

In [None]:
dl, ds = get_cifar(batch_size=128, num_classes=2, MSE=True, on_gpu=True, device=device)
print(len(dl))

torch.manual_seed(seed)
np.random.seed(seed)
print(next(iter(dl))[1].shape)
model = muMLPTab9(128).to(device)

model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1) 
for epoch in range(epochs):
    for i, (X, y) in enumerate(dl):
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(X)
        loss = F.MSE(out, y)
        loss.backward()
        optimizer.step()
    print(loss)

100%|██████████| 10000/10000 [00:01<00:00, 6967.96it/s]


Estimated size of the dataset in MB: 117.26
79
torch.Size([128, 2])


RuntimeError: 0D or 1D target tensor expected, multi-target not supported