In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pandas as pd
import pickle
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [5]:
# Load the pickled MNIST meta-learning data; it is indexed by string keys such
# as "meta_train_x" (see the next cell).
# NOTE: pickle.load can run arbitrary code -- only open files from a trusted source.
data_file = "mnist_data/mnist.pkl"
with open(data_file, mode="rb") as f:
    data = pickle.load(f)

In [9]:
# Display the inputs of the first meta-training task.
data["meta_train_x"][0]

array([[0.71428571, 0.96428571],
       [0.53571429, 0.64285714],
       [0.35714286, 0.32142857],
       [0.57142857, 0.14285714],
       [0.03571429, 0.53571429],
       [0.92857143, 0.32142857],
       [0.82142857, 0.03571429],
       [0.89285714, 0.75      ],
       [0.10714286, 0.32142857],
       [0.75      , 0.10714286],
       [0.10714286, 0.85714286],
       [0.46428571, 0.07142857],
       [0.85714286, 0.64285714],
       [0.28571429, 0.82142857],
       [0.71428571, 0.42857143],
       [0.28571429, 0.03571429],
       [0.03571429, 0.17857143],
       [0.5       , 0.75      ],
       [0.14285714, 0.5       ],
       [0.71428571, 0.92857143],
       [0.28571429, 0.53571429],
       [0.07142857, 0.32142857],
       [0.71428571, 0.67857143],
       [0.46428571, 0.71428571],
       [0.67857143, 0.14285714],
       [0.78571429, 0.89285714],
       [0.42857143, 0.78571429],
       [0.21428571, 0.32142857],
       [0.03571429, 0.46428571],
       [0.25      , 0.96428571],
       [0.

In [90]:
class FOMDataset(Dataset):
    """Tabular regression dataset read from a CSV file.

    Each item is ``(x, y)``: ``x`` is a float32 tensor holding one row's feature
    columns (in ``feature_cols`` order) and ``y`` is a float32 tensor of shape
    ``(1,)`` holding that row's target value.

    Args:
        csv_file: Path or buffer accepted by ``pandas.read_csv``.
        feature_cols: Input column name(s). A single string is treated as a
            one-element list. Defaults to ``("thickness", "wavelength")``.
        target_col: Target column name. Defaults to ``"fom"``.

    Raises:
        KeyError: If a requested column is missing from the CSV (raised by pandas).
    """

    def __init__(self, csv_file, feature_cols=("thickness", "wavelength"), target_col="fom"):
        # Load the CSV file
        self.data = pd.read_csv(csv_file)

        # A bare string would be iterated character by character by list(),
        # so wrap it; a list selection also keeps self.X 2-D for one column.
        if isinstance(feature_cols, str):
            feature_cols = [feature_cols]
        self.feature_cols = list(feature_cols)
        self.target_col = target_col

        # Extract features and target as NumPy arrays.
        self.X = self.data[self.feature_cols].values
        self.y = self.data[self.target_col].values

    def __len__(self):
        """Number of rows in the CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(features, target)`` for row ``idx`` as float32 tensors."""
        x = torch.tensor(self.X[idx], dtype=torch.float32)
        # unsqueeze(0) gives the scalar target a trailing axis: shape (1,).
        y = torch.tensor(self.y[idx], dtype=torch.float32).unsqueeze(0)
        return x, y

In [91]:
class SinusoidGenerator(Dataset):
    """Few-shot sinusoid regression tasks loaded from a pickle file.

    The pickle must hold a dict with arrays ``"x"`` and ``"y"`` indexed as
    ``[task, point, feature]``. The first ``few_k_shot`` points of each task form
    the context set; the remaining points form the target set.

    Args:
        train: If True, load ``sinusoidal_train.pkl`` and return
            ``(context_x, context_y, target_x, target_y)`` per task; otherwise
            load ``sinusoidal_test.pkl`` and return ``(x, y)`` with context and
            target points concatenated.
        few_k_shot: Number of leading points per task used as context.
        data_dir: Directory holding the two pickle files.

    Raises:
        ValueError: If ``few_k_shot`` is not between 1 and the number of points
            per task.
    """

    def __init__(self, train=True, few_k_shot=20, data_dir="sinusoidal_data/sinusoid_data"):
        self.few_k_shot = few_k_shot
        split = "train" if train else "test"
        data_file = f"{data_dir}/sinusoidal_{split}.pkl"

        # NOTE: pickle.load can run arbitrary code -- only open trusted files.
        with open(data_file, "rb") as f:
            data = pickle.load(f)

        num_points = data["x"].shape[1]
        if not 1 <= few_k_shot <= num_points:
            raise ValueError(
                f"few_k_shot must be between 1 and {num_points}, got {few_k_shot}"
            )

        # Both slices split at few_k_shot, so context and target partition each
        # task's points exactly (no overlap, no dropped points).
        self.data = {
            "train_x": torch.tensor(data["x"][:, :few_k_shot, :]),
            "train_y": torch.tensor(data["y"][:, :few_k_shot, :]),
            "test_x": torch.tensor(data["x"][:, few_k_shot:, :]),
            "test_y": torch.tensor(data["y"][:, few_k_shot:, :]),
        }

        print(
            "load data: train_x",
            self.data["train_x"].shape,
            "val_x",
            self.data["test_x"].shape,
            "train_y",
            self.data["train_y"].shape,
            "val_y",
            self.data["test_y"].shape,
        )

        self.train = train
        self.dim_input = 1
        self.dim_output = 1

    def generate_batch(self, indx):
        """Return task ``indx``.

        Training split: ``(context_x, context_y, target_x, target_y)``.
        Test split: ``(x, y)`` with context followed by target points along
        the point axis.
        """
        context_x = self.data["train_x"][indx]
        context_y = self.data["train_y"][indx]
        target_x = self.data["test_x"][indx]
        target_y = self.data["test_y"][indx]

        if self.train:
            return context_x, context_y, target_x, target_y
        return torch.cat((context_x, target_x)), torch.cat((context_y, target_y))

    def __len__(self):
        """Number of tasks (context and target tensors share the task axis)."""
        return self.data["train_x"].shape[0]

    def __getitem__(self, idx):
        return self.generate_batch(idx)

In [92]:
# Meta-train / meta-test task datasets. Training items are
# (context_x, context_y, target_x, target_y); test items are (x, y) with the
# context and target points concatenated.
sine_dataset_train = SinusoidGenerator(train=True, few_k_shot=20)
sine_dataset_test = SinusoidGenerator(train=False)

# batch_size=1: every batch is exactly one task. Only the training loader shuffles.
loader_options = {"batch_size": 1, "num_workers": 4}
sine_dataloader_train = DataLoader(sine_dataset_train, shuffle=True, **loader_options)
sine_dataloader_test = DataLoader(sine_dataset_test, shuffle=False, **loader_options)

load data: train_x torch.Size([240000, 20, 1]) val_x torch.Size([240000, 10, 1]) train_y torch.Size([240000, 20, 1]) val_y torch.Size([240000, 10, 1])
load data: train_x torch.Size([100, 20, 1]) val_x torch.Size([100, 100, 1]) train_y torch.Size([100, 20, 1]) val_y torch.Size([100, 100, 1])


In [93]:
# Sanity-check the tensor shapes of one task from each loader (batch axis first).
train_task = next(iter(sine_dataloader_train))
print(*(tensor.shape for tensor in train_task[:4]))

test_task = next(iter(sine_dataloader_test))
print(*(tensor.shape for tensor in test_task[:2]))

torch.Size([1, 20, 1]) torch.Size([1, 20, 1]) torch.Size([1, 10, 1]) torch.Size([1, 10, 1])
torch.Size([1, 120, 1]) torch.Size([1, 120, 1])


In [94]:
# Create the dataset and DataLoader
# csv_file = '/content/drive/MyDrive/IITP/Capstone II/fom.csv'  # Replace with the actual path to your CSV file
# fom_dataset = FOMDataset(csv_file)
# fom_dataloader = DataLoader(fom_dataset, batch_size=64, shuffle=True)

## Model definitions

In [95]:
# Baseline basis-function network for testing (similar to the one in the paper).
class BasisFunctionLearner(nn.Module):
    """Six-layer MLP that maps each input point to basis-function values Φ(x).

    Args:
        input_dim: Size of the last axis of the input.
        hidden_dim: Width of the five hidden layers.
        output_dim: Number of basis functions produced per input point.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Attribute names fix the state_dict keys and creation order fixes the
        # RNG draws used for initialisation, so keep both stable.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, hidden_dim)
        self.fc5 = nn.Linear(hidden_dim, hidden_dim)
        self.fc6 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return Φ(x): ReLU after each of fc1..fc5, then the linear fc6 head."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5):
            hidden = self.relu(layer(hidden))
        return self.fc6(hidden)


In [96]:
# Self-attention block as described in the paper.
class AttentionLayer(nn.Module):
    """Single-head self-attention followed by a two-layer feed-forward block.

    Each sub-block adds a residual and then applies a parameter-free
    ``F.layer_norm`` over every axis except the first.

    Args:
        input_dim: Size of the last axis of the input.
        hidden_dim: Size of the query/key/value projections and of the output.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.query = nn.Linear(input_dim, hidden_dim)
        self.key = nn.Linear(input_dim, hidden_dim)
        self.value = nn.Linear(input_dim, hidden_dim)
        # Feed-forward: expand to 2 * hidden_dim, then project back.
        self.ff_layer_0 = nn.Linear(hidden_dim, hidden_dim * 2)
        self.ff_layer_1 = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, x):
        q, k, v = self.query(x), self.key(x), self.value(x)

        # Scaled dot-product attention across the second-to-last axis.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
        attended = torch.matmul(F.softmax(scores, dim=-1), v)

        # Residual on the projected query (x itself may have input_dim != hidden_dim),
        # then normalise over all non-leading axes.
        h = attended + q
        h = F.layer_norm(h, h.size()[1:])

        # Feed-forward with residual and the same normalisation.
        h = h + self.ff_layer_1(F.relu(self.ff_layer_0(h)))
        return F.layer_norm(h, h.size()[1:])

In [97]:
# Weights generator as described in the paper.
class WeightsGenerator(nn.Module):
    """Maps a set of basis-function vectors to one weight vector per element.

    The network has ``1 + attention_layers`` ``AttentionLayer`` blocks; the
    first projects ``basis_function_dim`` -> ``hidden_dim``. They are followed
    by two ReLU dense layers and a linear projection back to
    ``basis_function_dim``.

    Args:
        basis_function_dim: Size of the last axis of the input and of the output.
        hidden_dim: Width of the attention stack and dense layers.
        attention_layers: Number of hidden_dim -> hidden_dim attention blocks
            after the first one.
    """

    def __init__(self, basis_function_dim, hidden_dim=512, attention_layers=8):
        super().__init__()
        self.top_attention_layer = AttentionLayer(basis_function_dim, hidden_dim)
        self.attention_layers = nn.ModuleList(
            AttentionLayer(hidden_dim, hidden_dim) for _ in range(attention_layers)
        )

        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu = nn.ReLU()

        self.final_dense = nn.Linear(hidden_dim, basis_function_dim)

    def forward(self, inputs):
        """Return the generated weights (same leading shape as ``inputs``)."""
        hidden = self.top_attention_layer(inputs)
        for block in self.attention_layers:
            hidden = block(hidden)
        for dense in (self.fc1, self.fc2):
            hidden = self.relu(dense(hidden))
        return self.final_dense(hidden)

In [98]:
class FewShotRegressionModel(nn.Module):
    """Predicts z = Φ(x) · w, with basis functions Φ and weights w both learned.

    Args:
        input_dim: Size of the last axis of ``x``.
        hidden_dim: Hidden width of the basis-function MLP.
        basis_function_dim: Number of basis functions (length of each weight row).
        output_dim: Currently unused; kept for call-site compatibility.

    NOTE(review): the weights generator only sees Φ(x), never the context
    targets y, so the generated weights cannot depend on which task the points
    came from. The paper's model presumably conditions w on (Φ(x), y) pairs --
    TODO confirm before relying on few-shot adaptation.
    """

    def __init__(self, input_dim, hidden_dim, basis_function_dim, output_dim):
        super(FewShotRegressionModel, self).__init__()
        self.basis_learner = BasisFunctionLearner(input_dim, hidden_dim, basis_function_dim)
        # Uses WeightsGenerator's own default hidden width (512), not hidden_dim.
        self.weights_generator = WeightsGenerator(basis_function_dim)

    def forward(self, x):
        """Return ``(weights, predictions)``.

        ``weights`` has the same shape as Φ(x). ``predictions`` drops the last
        axis: element i is the dot product of Φ(x_i) with its own weight row.
        """
        # Step 1: basis functions Φ(x)
        basis_functions = self.basis_learner(x)

        # Step 2: weights w (one row per input point)
        weights = self.weights_generator(basis_functions)

        # Step 3: z_i = Φ(x_i) · w_i. This equals diag(Φ @ wᵀ) for 2-D inputs,
        # but it does not build an N x N matrix just to read its diagonal. It also
        # does not depend on the 2-D-only `.T` / `torch.diag`.
        predictions = (basis_functions * weights).sum(dim=-1)
        return weights, predictions

In [99]:
def custom_loss_function(predictions, targets, weights, l1_lambda=0.001, l2_lambda=0.0001):
    """MSE between predictions and targets plus L1 and L2 penalties on ``weights``.

    Args:
        predictions: Model outputs. Their shape may differ from ``targets`` as
            long as the element counts match (e.g. ``(N,)`` vs ``(N, 1)``). In
            that case they are reshaped to ``targets.shape`` before the MSE.
        targets: Ground-truth values.
        weights: Tensor whose entries are penalised.
        l1_lambda: Coefficient on ``||weights||_1`` (sum of absolute entries).
        l2_lambda: Coefficient on ``||weights||_2`` (Euclidean norm over all
            entries, not squared).

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: If ``predictions`` and ``targets`` hold different numbers
            of elements.
    """
    if predictions.shape != targets.shape:
        if predictions.numel() != targets.numel():
            raise ValueError(
                f"predictions {tuple(predictions.shape)} and targets "
                f"{tuple(targets.shape)} have different numbers of elements"
            )
        # Without this, MSE broadcasts (N,) against (N, 1) into an (N, N) grid
        # and compares every prediction with every target.
        predictions = predictions.reshape(targets.shape)

    mse_loss = F.mse_loss(predictions, targets)
    l1_loss = l1_lambda * torch.norm(weights, p=1)
    l2_loss = l2_lambda * torch.norm(weights, p=2)
    return mse_loss + l1_loss + l2_loss


## Training loop

In [100]:
# Model and optimisation hyper-parameters.
input_dim = 1  # sinusoid x is scalar; e.g. (x1, x2) inputs would need 2
hidden_dim = 512
basis_function_dim = 256
output_dim = 1  # scalar prediction z
learning_rate = 0.005
num_epochs = 60000
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = FewShotRegressionModel(input_dim, hidden_dim, basis_function_dim, output_dim)
model = model.to(device)

trainable_params = (param for param in model.parameters() if param.requires_grad)
num_trainable_params = sum(param.numel() for param in trainable_params)
print(f"Number of trainable parameters: {num_trainable_params}")

Number of trainable parameters: 17989120


In [88]:
# Meta-train on the sinusoid tasks.
# Hyper-parameters come from the configuration cell above rather than being
# redefined here: input_dim, hidden_dim, basis_function_dim, output_dim,
# learning_rate, num_epochs, device.
# A fresh model and optimizer are built, so re-running this cell restarts training.
#
# NOTE(review): one "epoch" is a full pass over every training task. In the logged
# run that was 240000 tasks, roughly 10 h per pass with anomaly detection on.
# num_epochs=60000 is therefore far beyond a feasible budget; TODO confirm
# whether 60000 task iterations was intended instead.
model = FewShotRegressionModel(input_dim, hidden_dim, basis_function_dim, output_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Anomaly detection is a debugging aid that makes every backward pass much
# slower. Switch it off explicitly in case an earlier session enabled it.
torch.autograd.set_detect_anomaly(False)

for epoch in range(num_epochs):
    model.train()
    epoch_loss_total = 0.0
    num_tasks = 0

    # Each batch is one task (batch_size=1). The target points are not used by
    # the current objective, which fits the context points only.
    for batch_x_train, batch_y_train, _target_x, _target_y in tqdm(sine_dataloader_train):
        # squeeze(0) drops the size-1 batch axis.
        batch_x_train = batch_x_train.to(device, torch.float32).squeeze(0)
        batch_y_train = batch_y_train.to(device, torch.float32).squeeze(0)

        wts, predictions = model(batch_x_train)
        loss = custom_loss_function(predictions, batch_y_train, wts)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate on-device to avoid a host sync on every step.
        epoch_loss_total += loss.detach()
        num_tasks += 1

    # Report the mean over the epoch instead of the last task's noisy loss.
    mean_loss = float(epoch_loss_total) / max(num_tasks, 1)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Mean loss: {mean_loss:.4f}")

 10%|█         | 24167/240000 [1:01:27<9:08:53,  6.55it/s] 


KeyboardInterrupt: 

In [23]:
# After training, query the model on a new input point.
# NOTE(review): this query has two features, presumably the (thickness,
# wavelength) pair of the FOM dataset. The sinusoid configuration above uses
# input_dim = 1, so this only works with a model configured for FOM data.
query = (29.84, 1099.5)
x_new = torch.tensor([query], device=device)
if x_new.shape[-1] != input_dim:
    raise ValueError(
        f"query has {x_new.shape[-1]} features but the model expects input_dim={input_dim}"
    )

model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    z_pred = model(x_new)  # (weights, predictions)
print(f'Predicted z for input {query}: {z_pred[1].item()}')

Predicted z for input (29.84, 1099.5): 1381.579345703125
