In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
import numpy as np
from sklearn.decomposition import PCA


In [4]:

class RBFNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(RBFNN, self).__init__()
        self.centers = nn.Parameter(torch.randn(hidden_dim, input_dim))  #Trainable centers
        self.beta = nn.Parameter(torch.ones(1) * 1.0)  #Shared beta for all RBFs
        self.linear = nn.Linear(hidden_dim, output_dim)

    def rbf(self, x, centers, beta):
        size = (x.size(0), centers.size(0), x.size(1))
        x = x.unsqueeze(1).expand(size)
        c = centers.unsqueeze(0).expand(size)
        return torch.exp(-beta * ((x - c) ** 2).sum(2))

    def forward(self, x):
        phi = self.rbf(x, self.centers, self.beta)
        out = self.linear(phi)
        return out

def extract_features_resnet18(loader, device):
    resnet = resnet18(pretrained=True)
    resnet = nn.Sequential(*list(resnet.children())[:-1])  #Remove last FC layer
    resnet.to(device)
    resnet.eval()

    features = []
    labels = []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            feats = resnet(inputs).squeeze()
            features.append(feats.cpu())
            labels.append(targets)
    features = torch.cat(features, dim=0)
    labels = torch.cat(labels, dim=0)
    return features, labels

def find_nearest_samples(features, centers):
    distances = torch.cdist(centers, features)
    nearest_indices = torch.argmin(distances, dim=1)
    return nearest_indices

def load_cifar10(batch_size=256):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False, num_workers=2)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    return trainset, trainloader, testset, testloader

def train_rbfnn(features, labels, input_dim, hidden_dim, output_dim, device, epochs=20, batch_size=1024):
    model = RBFNN(input_dim, hidden_dim, output_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    features, labels = features.to(device), labels.to(device)
    n_samples = features.size(0)

    for epoch in range(epochs):
        permutation = torch.randperm(n_samples)
        epoch_loss = 0.0
        for i in range(0, n_samples, batch_size):
            indices = permutation[i:i+batch_size]
            batch_features = features[indices]
            batch_labels = labels[indices]

            optimizer.zero_grad()
            outputs = model(batch_features)
            loss = criterion(outputs, batch_labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}')
    return model

def train_resnet18_on_coreset(loader, device):
    model = resnet18(num_classes=10).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()

    model.train()
    for epoch in range(20):  #Train for fewer epochs since data is small
        running_loss = 0.0
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch {epoch+1}, Loss: {running_loss/len(loader):.4f}')
    return model

def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy on test set: {100 * correct / total:.2f}%')

In [5]:
def PCA_extract_features_resnet18(loader, device):
    resnet = resnet18(weights='IMAGENET1K_V1')
    resnet = nn.Sequential(*list(resnet.children())[:-1])  #Remove last FC layer
    resnet.to(device)
    resnet.eval()

    features = []
    labels = []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            feats = resnet(inputs).squeeze()  #Shape (batch_size, 512)
            features.append(feats.cpu())
            labels.append(targets)
    features = torch.cat(features, dim=0)
    labels = torch.cat(labels, dim=0)
    return features, labels

def PCA_train_resnet18_on_coreset(loader, device):
    model = resnet18(num_classes=10).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()

    model.train()
    for epoch in range(20):  #Train for fewer epochs since data is small
        running_loss = 0.0
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch {epoch+1}, Loss: {running_loss/len(loader):.4f}')
    return model

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [9]:
trainset, trainloader, testset, testloader = load_cifar10()

features, labels = PCA_extract_features_resnet18(trainloader, device)
print(f'Extracted features shape: {features.shape}')

pca = PCA(n_components=100)
features_pca = pca.fit_transform(features.numpy())
features_pca = torch.tensor(features_pca).float()
input_dim = features_pca.shape[1]

hidden_dim = 1000  #Number of RBF centers (coreset size)
output_dim = 10  #CIFAR-10 classes
rbfnn = train_rbfnn(features_pca, labels, input_dim, hidden_dim, output_dim, device)

centers = rbfnn.centers.detach().cpu()

coreset_indices = find_nearest_samples(features_pca, centers)
print(f'Coreset size: {len(coreset_indices)}')

coreset_data = torch.utils.data.Subset(trainset, coreset_indices)
coreset_loader = torch.utils.data.DataLoader(coreset_data, batch_size=128, shuffle=True, num_workers=2)




model = PCA_train_resnet18_on_coreset(coreset_loader, device)
evaluate(model, testloader, device)


Files already downloaded and verified
Files already downloaded and verified


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/gthampak/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:01<00:00, 32.7MB/s]


Extracted features shape: torch.Size([50000, 512])
Epoch 1/20, Loss: 112.8412
Epoch 2/20, Loss: 112.8441
Epoch 3/20, Loss: 112.8425
Epoch 4/20, Loss: 112.8391
Epoch 5/20, Loss: 112.8382
Epoch 6/20, Loss: 112.8469
Epoch 7/20, Loss: 112.8394
Epoch 8/20, Loss: 112.8375
Epoch 9/20, Loss: 112.8409
Epoch 10/20, Loss: 112.8427
Epoch 11/20, Loss: 112.8421
Epoch 12/20, Loss: 112.8390
Epoch 13/20, Loss: 112.8425
Epoch 14/20, Loss: 112.8386
Epoch 15/20, Loss: 112.8380
Epoch 16/20, Loss: 112.8413
Epoch 17/20, Loss: 112.8373
Epoch 18/20, Loss: 112.8411
Epoch 19/20, Loss: 112.8405
Epoch 20/20, Loss: 112.8415
Coreset size: 1000
Epoch 1, Loss: 4.9960
Epoch 2, Loss: 3.2210
Epoch 3, Loss: 2.0231
Epoch 4, Loss: 2.3333
Epoch 5, Loss: 2.4052
Epoch 6, Loss: 2.2516
Epoch 7, Loss: 1.3393
Epoch 8, Loss: 1.5307
Epoch 9, Loss: 2.1091
Epoch 10, Loss: 1.4859
Epoch 11, Loss: 1.3640
Epoch 12, Loss: 1.0276
Epoch 13, Loss: 0.8966
Epoch 14, Loss: 1.0287
Epoch 15, Loss: 0.8007
Epoch 16, Loss: 0.6996
Epoch 17, Loss: 0.59