In [1]:

import torch
from torch import nn
from icecream import ic
from torch.optim import Adam
from torchinfo import summary
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import Subset
from torch.optim import lr_scheduler
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, TensorDataset, random_split


# -------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(str(device))

N = 4  # number of images in whole picture
L_x = 5  # x and y pixel dims
L_y = 5
N_i = L_x * L_y
N_h = 256  # number of hidden units
N_o = 1  # number of output units


class IonImagesDataset(Dataset):
    def __init__(self, file_path):
        loaded_data_dict = torch.load(file_path)  # loading dataset into dataloader

        self.images = loaded_data_dict[
            "images"
        ]  # creating of 2 datasents of imgages and keys to them
        self.labels = loaded_data_dict["labels"]

    def __len__(self):
        return len(self.images)  # just return len function

    def __getitem__(self, idx):
        image_tensor = self.images[idx]  # Add a channel dimension
        label_tensor = self.labels[idx]  # Repeat the label for each ion position
        return image_tensor, label_tensor


file_path_pt = "binary\labels_and_images.pt"


dataset = IonImagesDataset(file_path_pt)
halfpi_dataset = IonImagesDataset(file_path_pt)

# Split the dataset into training and validation subsets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
print(
    f"Train size: {train_size}, Validation size: {val_size}"
)  # Print the sizes of subsets

train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

print(
    f"Train dataset size: {len(train_dataset)}"
)  # Print the size of the train_dataset
print(
    f"Validation dataset size: {len(val_dataset)}"
)  # Print the size of the val_dataset

# Create DataLoaders for the training and validation datasets
batch_size = min(1000, len(train_dataset))  # or choose a smaller value
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=False
)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=True, drop_last=False
)

halfpi_loader = DataLoader(
    halfpi_dataset, batch_size=batch_size, shuffle=True, drop_last=False
)


class IndexDependentDense(nn.Module):
    def __init__(self, N, N_i, N_o, activation=nn.ReLU()):
        super().__init__()

        self.N = N
        self.N_i = N_i
        self.N_o = N_o
        self.activation = activation
        self.register_parameter(
            "W", nn.Parameter(torch.empty(self.N, self.N_i, self.N_o))
        )
        self.register_parameter("b", nn.Parameter(torch.empty(self.N, self.N_o)))

        self._reset_parameters()

        pass

    def _reset_parameters(self):
        nn.init.xavier_uniform_(self.W)
        nn.init.zeros_(self.b)

    def forward(self, x):
        y = torch.einsum("nij,...ni->...nj", self.W, x) + self.b
        if self.activation is not None:
            return self.activation(y)
        else:
            return y

    pass


# ---------------------------------------------------------------------------------------------


class Encoder(nn.Module):
    def __init__(self, N, N_i, N_o):
        super().__init__()

        self.N = N
        self.N_i = N_i
        self.N_o = N_o

        self.dense = IndexDependentDense(N, N_i, N_o, activation=nn.ReLU())
        pass

    def forward(self, x):
        y = self.dense(x)
        return y

    pass


# ---------------------------------------------------------------------------------------------


class Classifier(nn.Module):
    def __init__(self, N, N_i, N_o):
        super().__init__()

        self.N = N
        self.N_i = N_i
        self.N_o = N_o
        self.dense = IndexDependentDense(N, N_i, N_o, activation=None)
        pass

    def forward(self, x):
        y = self.dense(x)
        y = torch.sigmoid(y)  # Apply sigmoid activation here
        return y

    pass


# ---------------------------------------------------------------------------------------------


class SharedEncoder(nn.Module):
    def __init__(self, N, N_i, N_o):
        super().__init__()

        self.N = N
        self.N_i = N_i
        self.N_o = N_o

        self.dense = nn.Linear(N_i, N_o)
        pass

    def forward(self, x):
        y = self.dense(x)
        return y

    pass


# ---------------------------------------------------------------------------------------------


class MultiIonReadout(nn.Module):
    def __init__(self, encoder, shared_encoder, classifier):
        super().__init__()

        self.encoder = encoder
        self.shared_encoder = shared_encoder
        self.classifier = classifier

    def forward(self, x):
        y = x.reshape(*x.shape[:-2], -1).to(torch.float32)
        y1 = self.encoder(y)
        y2 = self.shared_encoder(y)
        y_concat = torch.cat([y1, y2], dim=-1)
        y = self.classifier(y_concat)
        return y

    def bceloss(self, X, y):
        return F.binary_cross_entropy(self(X), y)

    @staticmethod
    def _accuracy(y_pred, y_true):
        mod_y_pred = (y_pred > 0.5).to(torch.float32)
        accuracy = (y_true == mod_y_pred).to(dtype=torch.float32).mean()
        return accuracy * 100

    def accuracy(self, x, y):
        return self._accuracy(self(x), y)


# ---------------------------------------------------------------------------------------------

cpu
Train size: 8000, Validation size: 2000
Train dataset size: 8000
Validation dataset size: 2000


In [2]:
device = torch.device("cpu")
# model 
encoder = Encoder(N, N_i, N_h)
shared_encoder = SharedEncoder(N, N_i, N_h)
classifier = Classifier(N, N_h * 2, N_o)
model = MultiIonReadout(encoder, shared_encoder, classifier)


model = model.to(device)
pytorch_total_params = sum(p.numel() for p in model.parameters())
pytorch_total_params2 = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(pytorch_total_params)
print(pytorch_total_params2)

N_epochs = 100
lr = 1e-3
optimizer = Adam(model.parameters(), lr=lr)
schedule_params = {"factor": 1}
schedule = lr_scheduler.ConstantLR(optimizer, **schedule_params)
log_every = 1

# Training loop
for epoch in range(N_epochs):

    total_train_loss = 0
    for inputs, labels in train_loader:

        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = model.bceloss(inputs, labels)
        loss.backward()
        optimizer.step()

    # Evaluation loop
    with torch.no_grad():
        total_loss = 0
        total_accuracy = 0
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            loss = model.bceloss(inputs, labels)
            accuracy = model.accuracy(inputs, labels)
            total_loss += loss.item()
            total_accuracy += accuracy.item()

        avg_loss = total_loss / len(val_loader)
        avg_accuracy = total_accuracy / len(val_loader)

    print(
        "\r Epoch {}/{}, Training Loss = {}, Val Loss = {}, Val Acc = {}".format(
            epoch + 1, N_epochs, loss.item(), avg_loss, avg_accuracy
        ),
        end="",
    )

# torch.save(model.state_dict(), "golden_WandB_n.pth")

35332
35332
 Epoch 100/100, Training Loss = 0.2069842368364334, Val Loss = 0.22685407102108002, Val Acc = 91.86249923706055