In [1]:
import os
import time

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


import matplotlib.pyplot as plt
from PIL import Image

from torchvision import datasets, transforms
from numpy.linalg import svd
from scipy.linalg import subspace_angles

from sklearn.decomposition import PCA
from numpy import linalg as LA
from scipy.linalg import sqrtm



# When a GPU is present, ask cuDNN to use deterministic algorithms so that runs
# with the same seeds are comparable.
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

In [2]:
##########################
### SETTINGS
##########################

# Hyperparameters
RANDOM_SEED = 1  # seed given to torch.manual_seed right before the model is built
LEARNING_RATE = 0.001  # Adam learning rate of the initial training run
BATCH_SIZE = 128
WORKERS = 2  # DataLoader worker processes
NUM_EPOCHS = 40  # the training loop saves one checkpoint per epoch
C = 0 # label-corruption fraction forwarded to get_mnist_loaders (0 disables it)
save_dir = f'MNIST_labelnoise{C}'  # checkpoint directory, named after the corruption level

# Architecture
NUM_FEATURES = 28*28  # NOTE(review): not referenced by any cell visible here
NUM_CLASSES = 10

# Other
DEVICE = "cuda:0"  # hard-coded GPU; the notebook has no CPU fallback
GRAYSCALE = True  # read by resnet18() to select a single input channel




In [3]:
# Create the checkpoint directory on first use. exist_ok=True keeps the call safe
# if the directory appears between the existence check and makedirs().
if not os.path.exists(save_dir):
    os.makedirs(save_dir, exist_ok=True)
    print(f'Created directory: {save_dir}')


Created directory: MNIST_labelnoise0


In [4]:

def get_mnist_loaders(batch_size=64, workers=4, corrupt=0.0, seed=42):
    """
    Build the MNIST train/test DataLoaders, optionally with noisy training labels.

    Parameters:
        batch_size (int): Batch size used by both loaders.
        workers (int): Number of DataLoader worker processes.
        corrupt (float): Fraction of training labels to replace (0 disables it).
        seed (int): Seed applied to the global torch and numpy RNGs.

    Returns:
        train_loader, test_loader: shuffled training loader and unshuffled test
        loader. Only training labels are ever corrupted.
    """
    # Seeds the *global* RNGs; numpy drives the label corruption below.
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Same preprocessing for both splits: tensor conversion, then normalization
    # with the constants below.
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    train_dataset = datasets.MNIST(root='data', train=True,
                                   transform=transform, download=True)

    if corrupt > 0:
        print(f'Applying {corrupt*100}% label corruption...')
        num_samples = len(train_dataset.targets)
        num_corrupt = int(num_samples * corrupt)
        # Distinct sample positions whose labels will be replaced.
        corrupt_indices = np.random.choice(num_samples, num_corrupt, replace=False)

        for idx in corrupt_indices:
            original_label = train_dataset.targets[idx].item()
            # Draw uniformly from the nine *other* digits so the label always changes.
            candidates = [digit for digit in range(10) if digit != original_label]
            train_dataset.targets[idx] = np.random.choice(candidates)

    # The test split keeps its original labels.
    test_dataset = datasets.MNIST(root='data', train=False, transform=transform)

    shared_kwargs = dict(batch_size=batch_size, num_workers=workers)
    train_loader = DataLoader(dataset=train_dataset, shuffle=True, **shared_kwargs)
    test_loader = DataLoader(dataset=test_dataset, shuffle=False, **shared_kwargs)

    return train_loader, test_loader





# Build the loaders from the notebook settings. `seed` is not passed, so the
# function default (42) is used here rather than RANDOM_SEED.
train_loader, test_loader = get_mnist_loaders(batch_size=BATCH_SIZE, workers=WORKERS, corrupt=C)




In [5]:
# Smoke test of the data pipeline: fetch only the first batch of two passes over
# the training loader, print its size and move it to the target device.
device = torch.device(DEVICE)
torch.manual_seed(0)  # re-seeds torch's global RNG before iterating the shuffled loader

for epoch in range(2):

    for batch_idx, (x, y) in enumerate(train_loader):
        
        print('Epoch:', epoch+1, end='')
        print(' | Batch index:', batch_idx, end='')
        print(' | Batch size:', y.size()[0])
        
        x = x.to(device)
        y = y.to(device)
        break  # one batch per pass is enough for this check


Epoch: 1 | Batch index: 0 | Batch size: 128
Epoch: 2 | Batch index: 0 | Batch size: 128


In [6]:
##########################
### MODEL
##########################


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )


class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages plus a skip connection.

    `downsample`, when given, is applied to the block input so the skip branch
    can be added to the main branch's output.
    """

    expansion = 1  # multiplier applied to `planes` by ResNet when sizing layers

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Attribute names and registration order are kept as-is: they define the
        # state_dict keys and the parameter ordering used elsewhere in the notebook.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Main branch: conv -> bn -> relu -> conv -> bn.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Skip branch: identity unless a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)




class ResNet(nn.Module):
    """ResNet backbone with a linear classifier head.

    Args:
        block: residual block class exposing an `expansion` attribute and the
            constructor signature `(inplanes, planes, stride, downsample)`.
        layers: four integers, the number of blocks in each stage.
        num_classes: size of the final linear layer's output.
        grayscale: if truthy the stem expects 1 input channel, otherwise 3.

    `forward` returns a `(logits, probas)` tuple.
    """

    def __init__(self, block, layers, num_classes, grayscale):
        # Running channel count consumed and updated by _make_layer.
        self.inplanes = 64
        in_dim = 1 if grayscale else 3
        super().__init__()

        # Stem
        self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four residual stages; every stage after the first uses stride 2.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Registered but not applied in forward().
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Conv weights ~ N(0, sqrt(2 / (k_h * k_w * out_channels)));
        # batch-norm layers start as the identity (weight 1, bias 0).
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, (2. / fan)**.5)
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack `blocks` residual blocks; only the first one may change stride/width."""
        out_planes = planes * block.expansion

        # A 1x1 conv + batch-norm on the skip path is needed whenever the first
        # block changes the spatial stride or the channel count.
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        # Average pooling is deliberately skipped: for MNIST-sized input the
        # feature map is already 1x1 at this point.
        x = x.view(x.size(0), -1)
        logits = self.fc(x)
        probas = F.softmax(logits, dim=1)
        return logits, probas



def resnet18(num_classes):
    """Construct a ResNet-18 (two BasicBlocks per stage) with `num_classes` outputs.

    The `num_classes` argument is now honoured; previously the global
    NUM_CLASSES was used regardless of what the caller passed. The input
    channel count still follows the notebook-level GRAYSCALE flag.
    """
    model = ResNet(block=BasicBlock,
                   layers=[2, 2, 2, 2],
                   num_classes=num_classes,
                   grayscale=GRAYSCALE)
    return model

In [7]:
def compute_accuracy(model, data_loader, device):
    """Return the classification accuracy of `model` over `data_loader`, in percent.

    Parameters:
        model: callable returning `(logits, probas)` for a batch of features.
        data_loader: iterable of `(features, targets)` tensor pairs.
        device: device the batches are moved to before the forward pass.

    Returns:
        A 0-dim float tensor holding the accuracy in [0, 100].

    Raises:
        ValueError: if `data_loader` yields no examples.

    The caller is still responsible for switching the model to eval mode.
    """
    correct_pred, num_examples = 0, 0
    # Evaluation never needs gradients. Disabling autograd here protects callers
    # that forget to wrap the call themselves from building graphs per batch.
    with torch.no_grad():
        for features, targets in data_loader:
            features = features.to(device)
            targets = targets.to(device)

            _, probas = model(features)
            predicted_labels = probas.argmax(dim=1)
            num_examples += targets.size(0)
            correct_pred += (predicted_labels == targets).sum()

    if num_examples == 0:
        # Previously this fell through to `int.float()` and raised AttributeError.
        raise ValueError('data_loader yielded no examples; accuracy is undefined')
    return correct_pred.float() / num_examples * 100

In [8]:
# Initial training run: train ResNet-18 with Adam and save one checkpoint per
# epoch. The checkpoints are what the later subspace cells load.
torch.manual_seed(RANDOM_SEED)

model = resnet18(NUM_CLASSES)
model.to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)  

    

start_time = time.time()
for epoch in range(NUM_EPOCHS):
    
    model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):
        
        features = features.to(DEVICE)
        targets = targets.to(DEVICE)

        logits, probas = model(features)
        cost = F.cross_entropy(logits, targets)
        # Gradients are cleared after the forward pass but before backward().
        optimizer.zero_grad()
        
        cost.backward()
        
        
        optimizer.step()
        
        
        # Progress line every 50 batches.
        if not batch_idx % 50:
            print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' 
                   %(epoch+1, NUM_EPOCHS, batch_idx, 
                     len(train_loader), cost))

        

    # End-of-epoch evaluation on the training set, in eval mode and without autograd.
    model.eval()
    with torch.set_grad_enabled(False): # save memory during inference
        print('Epoch: %03d/%03d | Train: %.3f%%' % (
              epoch+1, NUM_EPOCHS, 
              compute_accuracy(model, train_loader, device=DEVICE)))
        
    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))
    # Checkpoints are named "<epoch>.pt" with 1-based epochs. `param_filename`
    # stays a notebook global and ends up pointing at the last epoch's file.
    param_filename = os.path.join(save_dir, str(epoch+1 ) + '.pt')
    torch.save(model.state_dict(), param_filename)
    
print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))

Epoch: 001/040 | Batch 0000/0469 | Cost: 2.6219
Epoch: 001/040 | Batch 0050/0469 | Cost: 0.1830
Epoch: 001/040 | Batch 0100/0469 | Cost: 0.1559
Epoch: 001/040 | Batch 0150/0469 | Cost: 0.0656
Epoch: 001/040 | Batch 0200/0469 | Cost: 0.1591
Epoch: 001/040 | Batch 0250/0469 | Cost: 0.1879
Epoch: 001/040 | Batch 0300/0469 | Cost: 0.0329
Epoch: 001/040 | Batch 0350/0469 | Cost: 0.1030
Epoch: 001/040 | Batch 0400/0469 | Cost: 0.1210
Epoch: 001/040 | Batch 0450/0469 | Cost: 0.1472
Epoch: 001/040 | Train: 98.172%
Time elapsed: 0.29 min
Epoch: 002/040 | Batch 0000/0469 | Cost: 0.0902
Epoch: 002/040 | Batch 0050/0469 | Cost: 0.0462
Epoch: 002/040 | Batch 0100/0469 | Cost: 0.0333
Epoch: 002/040 | Batch 0150/0469 | Cost: 0.0369
Epoch: 002/040 | Batch 0200/0469 | Cost: 0.0319
Epoch: 002/040 | Batch 0250/0469 | Cost: 0.0355
Epoch: 002/040 | Batch 0300/0469 | Cost: 0.0269
Epoch: 002/040 | Batch 0350/0469 | Cost: 0.0469
Epoch: 002/040 | Batch 0400/0469 | Cost: 0.0170
Epoch: 002/040 | Batch 0450/0469 

In [9]:
# Accuracy of the current `model` weights on the (uncorrupted) test loader.
with torch.set_grad_enabled(False): # save memory during inference
    print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader, device=DEVICE)))

Test accuracy: 99.42%


In [10]:
def get_model_param_vec(model):
    """
    Flatten every named parameter of `model` into a single 1-D NumPy array.

    Parameters are visited in `model.named_parameters()` order and copied to
    host memory; buffers (e.g. batch-norm running statistics) are not included.
    """
    flat_chunks = [
        tensor.detach().cpu().numpy().reshape(-1)
        for _, tensor in model.named_parameters()
    ]
    return np.concatenate(flat_chunks, 0)

def get_model_grad_vec(model):
    """Return the model's gradients concatenated into one flat tensor.

    Parameters are visited in `model.named_parameters()` order. Parameters whose
    `.grad` is None (frozen, or not reached by the last backward pass) are
    skipped instead of raising AttributeError, which matches the later
    definition of this helper in the notebook.
    """
    vec = []
    for name, param in model.named_parameters():
        if param.grad is not None:
            vec.append(param.grad.detach().reshape(-1))
    return torch.cat(vec, 0)

def update_grad(model, grad_vec):
    """Overwrite the model's gradients from a flat gradient vector.

    `grad_vec` is consumed in `model.named_parameters()` order. Parameters whose
    `.grad` is None are skipped and consume no entries, so the layout matches
    what `get_model_grad_vec` produces. Previously such parameters raised
    AttributeError.

    The new gradients are views into `grad_vec` (no copy is made).
    """
    idx = 0
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        size = param.grad.numel()
        param.grad.data = grad_vec[idx:idx + size].reshape(param.grad.shape)
        idx += size

def update_param(model, param_vec):
    """Load a flat parameter vector back into the model, in place.

    Parameters:
        model: module whose parameters are overwritten, visited in
            `model.named_parameters()` order.
        param_vec: flat vector of parameter values; a NumPy array (as returned
            by `get_model_param_vec`) or a torch tensor.

    Values are copied into the existing parameter tensors, so each parameter
    keeps its device and dtype and does not alias `param_vec`. Previously the
    function assigned slices to `.data`, which rejected NumPy input and made
    the parameters views of the caller's tensor.
    """
    flat = torch.as_tensor(param_vec).reshape(-1)
    idx = 0
    with torch.no_grad():
        for name, param in model.named_parameters():
            size = param.numel()
            param.copy_(flat[idx:idx + size].reshape(param.shape))
            idx += size

In [11]:
# Settings for the subspace fine-tuning cells below.
START_EPOCH = 0  # checkpoints START_EPOCH+1 .. END_EPOCH are loaded (files are 1-based)
END_EPOCH = 40
n_components = 20  # dimension of the subspace extracted from the checkpoints
NUM_EPOCHS_FINE_TUNE = 20  # epochs run by each fine-tuning cell

In [12]:
def get_model_grad_vec(model):
    """Concatenate all available parameter gradients into one flat tensor.

    Parameters are visited in `model.parameters()` order; those without a
    gradient (`.grad is None`) are left out.
    """
    flat_grads = [
        p.grad.detach().reshape(-1)
        for p in model.parameters()
        if p.grad is not None
    ]
    return torch.cat(flat_grads, 0)

def update_grad(model, grad_vec):
    """Scatter a flat gradient vector back onto the model's `.grad` tensors.

    Walks `model.parameters()` in order; parameters without a gradient are
    skipped and consume no entries of `grad_vec`.
    """
    offset = 0
    for p in model.parameters():
        if p.grad is None:
            continue
        count = p.grad.numel()
        p.grad.data = grad_vec[offset:offset + count].reshape(p.grad.shape)
        offset += count
def load_saved_parameters(save_dir, start_epoch, end_epoch, net=None):
    """Load per-epoch checkpoints and stack their flattened parameters.

    Parameters:
        save_dir: directory holding checkpoints named "<epoch>.pt" (1-based).
        start_epoch, end_epoch: checkpoints start_epoch+1 .. end_epoch are read.
        net: module used to decode each state dict. Defaults to the
            notebook-level `model`, which is what this function always used.

    Returns:
        NumPy array with one row per checkpoint found.

    Side effect: `net` is left holding the weights of the last checkpoint that
    was loaded. Missing files are reported on stdout and skipped.
    """
    if net is None:
        # Backward-compatible default: the global model defined by the training cell.
        net = model
    W = []
    for epoch in range(start_epoch, end_epoch):
        param_filename = os.path.join(save_dir, f'{epoch + 1}.pt')
        if os.path.exists(param_filename):
            net.load_state_dict(torch.load(param_filename))
            W.append(get_model_param_vec(net))
        else:
            print(f'File not found: {param_filename}')
    W = np.array(W)
    print(f'Loaded {len(W)} parameter vectors with shape: {W.shape}')
    return W


def get_model_param_vec(model):
    """Return all named parameters of `model` as one flat NumPy vector.

    Each parameter is detached, moved to the CPU and flattened; the pieces are
    joined in `model.named_parameters()` order.
    """
    pieces = []
    for _, tensor in model.named_parameters():
        host_values = tensor.detach().cpu().numpy()
        pieces.append(host_values.reshape(-1))
    return np.concatenate(pieces, 0)


In [13]:
#PSGD





# Stack the saved checkpoints. As a side effect this leaves `model` at the last
# checkpoint that was loaded.
W = load_saved_parameters(save_dir, START_EPOCH, END_EPOCH)

# Basis of the training trajectory: the top principal directions of the checkpoints.
pca = PCA(n_components=n_components)
pca.fit(W)
P = np.array(pca.components_)  # one basis vector per row
print ('ratio:', pca.explained_variance_ratio_)
print ('P:', P.shape)
print(P.dtype)

# Keep the basis on the same device as the model (was a hard-coded .cuda()).
P = torch.from_numpy(P).to(DEVICE)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss().to(DEVICE)

# Training parameters

alpha = 0.1  # Learning rate for residual gradient


for epoch in range(NUM_EPOCHS_FINE_TUNE):
    model.train()

    for batch_idx, (features, targets) in enumerate(train_loader):
        features = features.to(DEVICE)
        targets = targets.to(DEVICE)

        # Forward pass
        logits, probas = model(features)
        cost = criterion(logits, targets)

        # Backward pass to compute gradients
        optimizer.zero_grad()
        cost.backward()

        # Full gradient as one flat vector (helper defined earlier in the notebook).
        grad_vec = get_model_grad_vec(model)

        # Split the gradient into its component inside span(P) and the residual.
        gk = torch.mm(P, grad_vec.reshape(-1, 1))
        grad_proj = torch.mm(P.T, gk).reshape(-1)
        grad_res = grad_vec - grad_proj

        # Step 1: update with the projected gradient.
        update_grad(model, grad_proj)
        optimizer.step()

        # Step 2: update with the residual, scaled down by alpha. Both steps go
        # through the same SGD optimizer, so its momentum buffer sees each of them.
        update_grad(model, alpha * grad_res)
        optimizer.step()

        # Logging every 50 batches
        #if batch_idx % 50 == 0:
        #    print(f'Epoch: {epoch + 1}/{NUM_EPOCHS_FINE_TUNE} | Batch {batch_idx}/{len(train_loader)} | Cost: {cost:.4f}')

    # Evaluate the model after each epoch; no autograd graph is needed here.
    model.eval()
    with torch.no_grad():
        train_acc = compute_accuracy(model, train_loader, DEVICE)
        test_acc = compute_accuracy(model, test_loader, DEVICE)
    print(f'Epoch {epoch + 1}/{NUM_EPOCHS_FINE_TUNE} | Train Acc: {train_acc:.2f}% | Test Acc: {test_acc:.2f}%')


# Restore the checkpoint named by the global `param_filename` (set by the initial
# training loop), discarding the fine-tuned weights. This line was indented by
# two spaces, which is an IndentationError at cell level.
# NOTE(review): assumed to be a cell-level statement -- confirm the intent.
model.load_state_dict(torch.load(param_filename))


Loaded 40 parameter vectors with shape: (40, 11175370)
ratio: [0.810835   0.07315665 0.03047929 0.01720474 0.01142453 0.00814474
 0.00605126 0.00476951 0.00374862 0.00328506 0.00290904 0.00254596
 0.0022525  0.00207408 0.00190761 0.0017146  0.00157571 0.00143144
 0.00136058 0.0011949 ]
P: (20, 11175370)
float32
Epoch 1/20 | Train Acc: 99.98% | Test Acc: 99.43%
Epoch 2/20 | Train Acc: 99.99% | Test Acc: 99.49%
Epoch 3/20 | Train Acc: 99.99% | Test Acc: 99.44%
Epoch 4/20 | Train Acc: 99.99% | Test Acc: 99.47%
Epoch 5/20 | Train Acc: 99.99% | Test Acc: 99.49%
Epoch 6/20 | Train Acc: 99.99% | Test Acc: 99.47%
Epoch 7/20 | Train Acc: 99.99% | Test Acc: 99.50%
Epoch 8/20 | Train Acc: 99.99% | Test Acc: 99.49%
Epoch 9/20 | Train Acc: 100.00% | Test Acc: 99.52%
Epoch 10/20 | Train Acc: 100.00% | Test Acc: 99.52%
Epoch 11/20 | Train Acc: 99.99% | Test Acc: 99.49%
Epoch 12/20 | Train Acc: 100.00% | Test Acc: 99.49%
Epoch 13/20 | Train Acc: 100.00% | Test Acc: 99.52%
Epoch 14/20 | Train Acc: 100.

In [14]:
#TME


def m_estimator(X, tol=1e-12, max_iter=1000, eps=1e-10):
    """Iteratively re-weighted, trace-normalised covariance estimate of the rows of X.

    Each iteration weights every row x_i by 1 / (x_i^T C^{-1} conj(x_i) + eps),
    rebuilds the weighted scatter matrix and rescales it to unit trace.

    Parameters:
        X: array of shape (N, D); one sample per row.
        tol: stop once the Frobenius norm of the change in C is <= tol.
        max_iter: iteration counter limit. The counter starts at 1, so at most
            max_iter - 1 updates are performed, as before.
        eps: regulariser added to C before inversion and to the weights'
            denominator.

    Returns:
        (D, D) array with unit trace.

    The defaults reproduce the previously hard-coded constants.
    """
    N, D = X.shape
    identity = np.eye(D)
    ridge = eps * identity  # loop-invariant regulariser
    cov = identity
    oldcov = identity - 1  # deliberately different from cov so the loop is entered
    iter_count = 1

    while np.linalg.norm(oldcov - cov, 'fro') > tol and iter_count < max_iter:
        temp = X @ np.linalg.inv(cov + ridge)
        # Per-row quadratic form x_i^T C^{-1} conj(x_i).
        d = np.sum(temp * np.conjugate(X), axis=1)
        oldcov = cov

        weights = (np.real(d) + eps) ** (-1)

        # Broadcasting scales column i of X.T by weights[i]; this equals
        # X.T @ diag(weights) without materialising an N x N matrix.
        cov = (X.T * weights) @ X / (N * D)
        cov = cov / np.trace(cov)
        iter_count += 1

    return cov



# Stack the saved checkpoints. As a side effect this leaves `model` at the last
# checkpoint that was loaded.
W = load_saved_parameters(save_dir, START_EPOCH, END_EPOCH)

# Work in the small checkpoint-by-checkpoint space: W_hat is the matrix square
# root of the Gram matrix of the checkpoints.
V = np.dot(W,W.T)
W_hat = sqrtm(V)

Cov = m_estimator(W_hat)

# Principal directions of the estimated covariance matrix.
pca = PCA(n_components=n_components)
pca.fit(Cov)
U = np.array(pca.components_)
print('U:',U.shape)

# Map the low-dimensional directions back to parameter space. The small
# (checkpoints x components) product is formed first so that no intermediate of
# parameter-space size is materialised.
P = (W.T) @ (LA.inv(W_hat) @ (U.T))

print ('P:', P.shape)
P = P.T  # one basis vector per row, as the projection below expects
P = P.astype(np.float32)

# Keep the basis on the same device as the model (was a hard-coded .cuda()).
P = torch.from_numpy(P).to(DEVICE)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss().to(DEVICE)

# Training parameters

alpha = 0.1  # Learning rate for residual gradient


for epoch in range(NUM_EPOCHS_FINE_TUNE):
    model.train()

    for batch_idx, (features, targets) in enumerate(train_loader):
        features = features.to(DEVICE)
        targets = targets.to(DEVICE)

        # Forward pass
        logits, probas = model(features)
        cost = criterion(logits, targets)

        # Backward pass to compute gradients
        optimizer.zero_grad()
        cost.backward()

        # Full gradient as one flat vector (helper defined earlier in the notebook).
        grad_vec = get_model_grad_vec(model)

        # Split the gradient into its component inside span(P) and the residual.
        gk = torch.mm(P, grad_vec.reshape(-1, 1))
        grad_proj = torch.mm(P.T, gk).reshape(-1)
        grad_res = grad_vec - grad_proj

        # Step 1: update with the projected gradient.
        update_grad(model, grad_proj)
        optimizer.step()

        # Step 2: update with the residual, scaled down by alpha. Both steps go
        # through the same SGD optimizer, so its momentum buffer sees each of them.
        update_grad(model, alpha * grad_res)
        optimizer.step()

        # Logging every 50 batches
        #if batch_idx % 50 == 0:
        #    print(f'Epoch: {epoch + 1}/{NUM_EPOCHS_FINE_TUNE} | Batch {batch_idx}/{len(train_loader)} | Cost: {cost:.4f}')

    # Evaluate the model after each epoch; no autograd graph is needed here.
    model.eval()
    with torch.no_grad():
        train_acc = compute_accuracy(model, train_loader, DEVICE)
        test_acc = compute_accuracy(model, test_loader, DEVICE)
    print(f'Epoch {epoch + 1}/{NUM_EPOCHS_FINE_TUNE} | Train Acc: {train_acc:.2f}% | Test Acc: {test_acc:.2f}%')


# Restore the checkpoint named by the global `param_filename` (set by the initial
# training loop), discarding the fine-tuned weights. This line was indented by
# two spaces, which is an IndentationError at cell level.
# NOTE(review): assumed to be a cell-level statement -- confirm the intent.
model.load_state_dict(torch.load(param_filename))


Loaded 40 parameter vectors with shape: (40, 11175370)
U: (20, 40)
P: (11175370, 20)
Epoch 1/20 | Train Acc: 99.98% | Test Acc: 99.46%
Epoch 2/20 | Train Acc: 99.98% | Test Acc: 99.43%
Epoch 3/20 | Train Acc: 99.99% | Test Acc: 99.47%
Epoch 4/20 | Train Acc: 99.99% | Test Acc: 99.47%
Epoch 5/20 | Train Acc: 99.99% | Test Acc: 99.46%
Epoch 6/20 | Train Acc: 99.99% | Test Acc: 99.52%
Epoch 7/20 | Train Acc: 99.99% | Test Acc: 99.51%
Epoch 8/20 | Train Acc: 99.99% | Test Acc: 99.50%
Epoch 9/20 | Train Acc: 99.99% | Test Acc: 99.47%
Epoch 10/20 | Train Acc: 100.00% | Test Acc: 99.53%
Epoch 11/20 | Train Acc: 100.00% | Test Acc: 99.48%
Epoch 12/20 | Train Acc: 100.00% | Test Acc: 99.54%
Epoch 13/20 | Train Acc: 100.00% | Test Acc: 99.50%
Epoch 14/20 | Train Acc: 100.00% | Test Acc: 99.47%
Epoch 15/20 | Train Acc: 100.00% | Test Acc: 99.52%
Epoch 16/20 | Train Acc: 100.00% | Test Acc: 99.53%
Epoch 17/20 | Train Acc: 100.00% | Test Acc: 99.50%
Epoch 18/20 | Train Acc: 100.00% | Test Acc: 99.4

In [15]:
#FMS

def FMS(X, dd, tol=1e-12, max_iter=1000):
    """Estimate a dd-dimensional subspace of the columns of X by iterative re-weighting.

    Starting from the top-dd left singular vectors of X, each iteration weights
    every column by the inverse of its distance to the current subspace and
    takes the top-dd left singular vectors of the re-weighted scatter matrix.

    Parameters:
        X: array of shape (D, N); one data point per column.
        dd: dimension of the subspace to recover.
        tol: stop once the norm of the principal angles between successive
            subspaces is <= tol.
        max_iter: iteration counter limit. The counter starts at 1, so at most
            max_iter - 1 updates are performed, as before.

    Returns:
        (D, dd) array whose columns span the estimated subspace.

    The defaults reproduce the previously hard-coded constants.
    """
    D, N = X.shape
    identity = np.eye(D)

    # Initialise with the top-dd left singular vectors of the data.
    U, _, _ = svd(X, full_matrices=False)
    L = U[:, :dd]

    ang = 1  # any value above tol, so the loop is entered
    n_iter = 1  # renamed from `iter`, which shadowed the builtin

    while ang > tol and n_iter < max_iter:
        Lold = L

        # Distance of every column to the current subspace; the small constant
        # keeps the weights finite for points lying on it.
        residual = (identity - L @ L.T) @ X
        w = np.sqrt(np.sum(residual**2, axis=0)) + 1e-10

        # Broadcasting scales column j of X by 1 / w[j]; this equals
        # X @ diag(1 / w) without materialising an N x N matrix.
        XX = (X * (1.0 / w)) @ X.T

        U, _, _ = svd(XX, full_matrices=False)
        L = U[:, :dd]

        # Convergence measure: norm of the principal angles between old and new.
        ang = np.linalg.norm(subspace_angles(L, Lold))

        n_iter += 1

    return L

# Stack the saved checkpoints. As a side effect this leaves `model` at the last
# checkpoint that was loaded.
W = load_saved_parameters(save_dir, START_EPOCH, END_EPOCH)

# Work in the small checkpoint-by-checkpoint space: W_hat is the matrix square
# root of the Gram matrix of the checkpoints.
V = np.dot(W,W.T)
W_hat = sqrtm(V)

# Subspace estimate in the small space (the no-op `n_components = n_components`
# assignment was dropped).
U = FMS(W_hat, n_components)

# Map the low-dimensional directions back to parameter space. The small
# (checkpoints x components) product is formed first so that no intermediate of
# parameter-space size is materialised.
P = (W.T) @ (LA.inv(W_hat) @ (U))

print ('P:', P.shape)
P = P.T  # one basis vector per row, as the projection below expects
P = P.astype(np.float32)

# Keep the basis on the same device as the model (was a hard-coded .cuda()).
P = torch.from_numpy(P).to(DEVICE)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss().to(DEVICE)

# Training parameters

alpha = 0.1  # Learning rate for residual gradient


for epoch in range(NUM_EPOCHS_FINE_TUNE):
    model.train()

    for batch_idx, (features, targets) in enumerate(train_loader):
        features = features.to(DEVICE)
        targets = targets.to(DEVICE)

        # Forward pass
        logits, probas = model(features)
        cost = criterion(logits, targets)

        # Backward pass to compute gradients
        optimizer.zero_grad()
        cost.backward()

        # Full gradient as one flat vector (helper defined earlier in the notebook).
        grad_vec = get_model_grad_vec(model)

        # Split the gradient into its component inside span(P) and the residual.
        gk = torch.mm(P, grad_vec.reshape(-1, 1))
        grad_proj = torch.mm(P.T, gk).reshape(-1)
        grad_res = grad_vec - grad_proj

        # Step 1: update with the projected gradient.
        update_grad(model, grad_proj)
        optimizer.step()

        # Step 2: update with the residual, scaled down by alpha. Both steps go
        # through the same SGD optimizer, so its momentum buffer sees each of them.
        update_grad(model, alpha * grad_res)
        optimizer.step()

        # Logging every 50 batches
        #if batch_idx % 50 == 0:
        #    print(f'Epoch: {epoch + 1}/{NUM_EPOCHS_FINE_TUNE} | Batch {batch_idx}/{len(train_loader)} | Cost: {cost:.4f}')

    # Evaluate the model after each epoch; no autograd graph is needed here.
    model.eval()
    with torch.no_grad():
        train_acc = compute_accuracy(model, train_loader, DEVICE)
        test_acc = compute_accuracy(model, test_loader, DEVICE)
    print(f'Epoch {epoch + 1}/{NUM_EPOCHS_FINE_TUNE} | Train Acc: {train_acc:.2f}% | Test Acc: {test_acc:.2f}%')


# Restore the checkpoint named by the global `param_filename` (set by the initial
# training loop), discarding the fine-tuned weights. This line was indented by
# two spaces, which is an IndentationError at cell level.
# NOTE(review): assumed to be a cell-level statement -- confirm the intent.
model.load_state_dict(torch.load(param_filename))


Loaded 40 parameter vectors with shape: (40, 11175370)
P: (11175370, 20)
Epoch 1/20 | Train Acc: 99.98% | Test Acc: 99.44%
Epoch 2/20 | Train Acc: 99.98% | Test Acc: 99.46%
Epoch 3/20 | Train Acc: 99.99% | Test Acc: 99.51%
Epoch 4/20 | Train Acc: 99.99% | Test Acc: 99.44%
Epoch 5/20 | Train Acc: 99.99% | Test Acc: 99.49%
Epoch 6/20 | Train Acc: 99.99% | Test Acc: 99.45%
Epoch 7/20 | Train Acc: 99.99% | Test Acc: 99.47%
Epoch 8/20 | Train Acc: 99.99% | Test Acc: 99.47%
Epoch 9/20 | Train Acc: 100.00% | Test Acc: 99.48%
Epoch 10/20 | Train Acc: 100.00% | Test Acc: 99.45%
Epoch 11/20 | Train Acc: 100.00% | Test Acc: 99.51%
Epoch 12/20 | Train Acc: 100.00% | Test Acc: 99.48%
Epoch 13/20 | Train Acc: 100.00% | Test Acc: 99.47%
Epoch 14/20 | Train Acc: 100.00% | Test Acc: 99.52%
Epoch 15/20 | Train Acc: 100.00% | Test Acc: 99.53%
Epoch 16/20 | Train Acc: 100.00% | Test Acc: 99.50%
Epoch 17/20 | Train Acc: 100.00% | Test Acc: 99.50%
Epoch 18/20 | Train Acc: 100.00% | Test Acc: 99.49%
Epoch 19