In [1]:
import logging
from argparse import ArgumentParser

import torch
import torch.nn.functional as F
# import wandb
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import DataLoader
from tqdm import trange

from experiments.data import INRDataset
from experiments.utils import (
    common_parser,
    count_parameters,
    get_device,
    set_logger,
    set_seed,
    str2bool,
)
from nn.models import DWSModelForClassification, MLPModelForClassification

from experiments.mnist.generate_data_splits import generate_splits
from experiments.mnist.compute_statistics import compute_stats

set_logger()

In [3]:
print(torch.cuda.current_device())  # The ID of the current GPU
print(torch.cuda.get_device_name(0))  # The name of the specified GPU
print(torch.cuda.device_count())  # The amount of GPUs that are accessible

0
Tesla P100-PCIE-12GB
1


In [2]:
torch.cuda.empty_cache()

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [30]:
#Loading inr data we created while mnist training
import os
current_working_directory = os.getcwd()
print(current_working_directory)
path = current_working_directory + "/notebooks/dataset/mnist_splits.json"
statistics_path = current_working_directory + "/notebooks/dataset/statistics.pth"
normalize = True
augmentation = True

batch_size = 32
num_workers = 1

/work/talisman/sgupta/DWSNets/equivariant-diffusion


In [31]:
train_set = INRDataset(
        path=path,
        split="train",
        normalize=normalize,
        augmentation=augmentation,
        statistics_path=statistics_path,
    )

val_set = INRDataset(
    path=path,
    split="val",
    normalize=normalize,
    statistics_path=statistics_path,
)

test_set = INRDataset(
    path=path,
    split="test",
    normalize=normalize,
    statistics_path=statistics_path,
)

train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
val_loader = torch.utils.data.DataLoader(
    dataset=val_set,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
)

logging.info(
    f"train size {len(train_set)}, "
    f"val size {len(val_set)}, "
    f"test size {len(test_set)}"
)

2024-08-02 13:00:25,460 - root - INFO - train size 55000, val size 5000, test size 10000


In [22]:
print(len(train_set))

55000


In [92]:
class Latent_AE_cnn(nn.Module):
    def __init__(
            self,
            in_dim,
            time_step=1000,
    ):
        super().__init__()

        # self.enc1 = nn.Sequential(nn.Conv1d(1, 10, 3, stride=1),nn.LeakyReLU(),nn.Conv1d(1, 10, 3, stride=1),)
        self.in_dim = in_dim
        self.fold_rate = 5
        self.kernal_size = 3
        self.channel_list = [4, 4, 4, 4]
        self.channel_list_dec = [8, 256, 256, 4]
        print(self.fold_rate)
        print(self.kernal_size)
        print(self.channel_list)
        print(self.channel_list_dec)
        self.real_input_dim = (
                int(in_dim / self.fold_rate ** 4 + 1) * self.fold_rate ** 4
        )

        self.enc1 = nn.Sequential(
            nn.InstanceNorm1d(self.real_input_dim),
            nn.Conv1d(1, self.channel_list[0], self.kernal_size, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.InstanceNorm1d(self.real_input_dim),
            nn.Conv1d(self.channel_list[0], self.channel_list[0], self.kernal_size, stride=self.fold_rate, padding=0),
            # nn.MaxPool1d(2),
        )
        self.enc2 = nn.Sequential(
            nn.LeakyReLU(),
            nn.InstanceNorm1d(self.real_input_dim // self.fold_rate),
            nn.Conv1d(self.channel_list[0], self.channel_list[0], self.kernal_size, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.InstanceNorm1d(self.real_input_dim // self.fold_rate),
            nn.Conv1d(self.channel_list[0], self.channel_list[1], self.kernal_size, stride=self.fold_rate, padding=0),
            # nn.MaxPool1d(2),
        )
        self.enc3 = nn.Sequential(
            nn.LeakyReLU(),
            nn.InstanceNorm1d(self.real_input_dim // self.fold_rate ** 2),
            nn.Conv1d(self.channel_list[1], self.channel_list[1], self.kernal_size, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.InstanceNorm1d(self.real_input_dim // self.fold_rate ** 2),
            nn.Conv1d(self.channel_list[1], self.channel_list[2], self.kernal_size, stride=self.fold_rate, padding=0),
            # nn.MaxPool1d(2),
        )
        self.enc4 = nn.Sequential(
            nn.LeakyReLU(),
            nn.InstanceNorm1d(self.real_input_dim // self.fold_rate ** 3),
            nn.Conv1d(self.channel_list[2], self.channel_list[2], self.kernal_size, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.InstanceNorm1d(self.real_input_dim // self.fold_rate ** 3),
            nn.Conv1d(self.channel_list[2], self.channel_list[3], self.kernal_size, stride=self.fold_rate, padding=0),
            nn.Tanh(),
        )

        self.dec1 = nn.Sequential(
            nn.LeakyReLU(),
            nn.InstanceNorm1d(self.real_input_dim // self.fold_rate ** 4),
            nn.ConvTranspose1d(
                self.channel_list_dec[3], self.channel_list_dec[3], self.kernal_size, stride=self.fold_rate, padding=0
            ),
            nn.LeakyReLU(),
            nn.InstanceNorm1d(self.real_input_dim // self.fold_rate ** 4),
            nn.Conv1d(self.channel_list_dec[3], self.channel_list_dec[2], self.kernal_size, stride=1,
                      padding=self.fold_rate - 1),
        )
        self.dec2 = nn.Sequential(
            nn.LeakyReLU(),
            nn.InstanceNorm1d(self.real_input_dim // self.fold_rate ** 3),
            nn.ConvTranspose1d(
                self.channel_list_dec[2], self.channel_list_dec[2], self.kernal_size, stride=self.fold_rate, padding=0
            ),
            nn.LeakyReLU(),
            nn.InstanceNorm1d(self.real_input_dim // self.fold_rate ** 3),
            nn.Conv1d(self.channel_list_dec[2], self.channel_list_dec[1], self.kernal_size, stride=1,
                      padding=self.fold_rate - 1),
        )
        self.dec3 = nn.Sequential(
            nn.LeakyReLU(),
            nn.InstanceNorm1d(self.real_input_dim // self.fold_rate ** 2),
            nn.ConvTranspose1d(
                self.channel_list_dec[1], self.channel_list_dec[1], self.kernal_size, stride=self.fold_rate, padding=0
            ),
            nn.LeakyReLU(),
            nn.InstanceNorm1d(self.real_input_dim // self.fold_rate ** 2),
            nn.Conv1d(self.channel_list_dec[1], self.channel_list_dec[0], self.kernal_size, stride=1,
                      padding=self.fold_rate - 1),
        )
        self.dec4 = nn.Sequential(
            nn.LeakyReLU(),
            nn.InstanceNorm1d(self.real_input_dim // self.fold_rate),
            nn.ConvTranspose1d(
                self.channel_list_dec[0], self.channel_list_dec[0], self.kernal_size, stride=self.fold_rate, padding=0
            ),
            nn.LeakyReLU(),
            nn.InstanceNorm1d(self.real_input_dim // self.fold_rate),
            nn.Conv1d(self.channel_list_dec[0], 1, self.kernal_size, stride=1, padding=self.fold_rate),
        )

        # self.time_encode = nn.Embedding(time_step, self.real_input_dim)

    def forward(self, input):
        input_shape = input.shape
        if len(input.size()) == 2:
            input = input.view(input.size(0), 1, -1)

        input = torch.cat(
            [
                input,
                torch.zeros(input.shape[0], 1, (self.real_input_dim - self.in_dim)).to(
                    input.device
                ),
            ],
            dim=2,
        )
        emb_enc1 = self.enc1(input)
        emb_enc2 = self.enc2(emb_enc1)
        emb_enc3 = self.enc3(emb_enc2)
        emb_enc4 = self.enc4(emb_enc3)

        emb_enc4 = emb_enc4 + torch.randn_like(emb_enc4) * 0.1

        emb_dec1 = self.dec1(emb_enc4)
        emb_dec2 = self.dec2(emb_dec1)
        emb_dec3 = self.dec3(emb_dec2)
        emb_dec4 = self.dec4(emb_dec3)[:, :, :input_shape[-1]]

        return emb_dec4.reshape(input_shape)

    def Enc(self, input):
        if len(input.size()) == 2:
            input = input.view(input.size(0), 1, -1)

        input = torch.cat(
            [
                input,
                torch.zeros(input.shape[0], 1, (self.real_input_dim - self.in_dim)).to(input.device),
            ],
            dim=2,
        )
        emb_enc1 = self.enc1(input)
        emb_enc2 = self.enc2(emb_enc1)
        emb_enc3 = self.enc3(emb_enc2)
        emb_enc4 = self.enc4(emb_enc3)

        return emb_enc4

    def Dec(self, emb_enc4):
        emb_dec1 = self.dec1(emb_enc4)
        emb_dec2 = self.dec2(emb_dec1)
        emb_dec3 = self.dec3(emb_dec2)
        emb_dec4 = self.dec4(emb_dec3)[:, :, :self.in_dim]

        return emb_dec4

In [93]:
def reshape_input(batch):

    weights, biases = batch.weights, batch.biases

    #  Flatten weights and biases
    weights_flat = [w.view(w.size(0), -1) for w in weights]
    biases_flat = [b.view(b.size(0), -1) for b in biases]

    concatenated_layers = []

    for w, b in zip(weights_flat, biases_flat):
        concatenated_layers.append(w)
        concatenated_layers.append(b)

    # Concatenate all layers along the feature dimension
    final_concatenated = torch.cat(concatenated_layers, dim=1)
    return final_concatenated

In [94]:
from tqdm import trange
from torch.optim.lr_scheduler import StepLR

def train_model(model):
    learning_rate = 1e-3
    num_epochs = 100
    criterion =  nn.MSELoss()
    epoch_iter = trange(num_epochs)
    epoch_loss = -1
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in epoch_iter:
        total_loss = 0
        counter = 0
        for i, batch in enumerate(train_loader):
            model.train()
            optimizer.zero_grad()

            batch = batch.to(device)
            data = reshape_input(batch)
            out = model(data)
            loss = criterion(out, data)
            loss.backward()
            optimizer.step()

            epoch_iter.set_description(
                f"[{epoch} {i+1}], train loss: {loss.item():.3f}, epoch loss: {epoch_loss:.3f}"
            )
            
            total_loss = total_loss + loss.item()
            counter +=1
        epoch_loss = total_loss/counter
        if (epoch+1)%25 == 0:
            torch.save(model.state_dict(), f"Outputs/model4_epoch{epoch}_loss{epoch_loss}.pth")

In [None]:
import warnings
warnings.filterwarnings("ignore")

model = Latent_AE_cnn(
   in_dim = 1185
).to(device)
train_model(model)

  0%|          | 0/100 [00:00<?, ?it/s]

5
3
[4, 4, 4, 4]
[8, 256, 256, 4]


[11 1487], train loss: 0.996, epoch loss: 1.025:  11%|█         | 11/100 [54:23<5:41:34, 230.27s/it]