In [1]:
# Re-import edited project modules (e.g. transphorm.*) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path

import lightning as L
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import nn

from transphorm.model_components.data_objects import SyntheticFPDataModule
from transphorm.model_components.model_modules import VanillaAutoEncoder, Encoder, Decoder

In [3]:


class CnnEncoder(nn.Module):
    """Convolutional encoder mapping a multi-channel 2-D input to a latent vector.

    Pipeline: ``Conv2d -> ReLU -> Dropout -> Flatten -> MLP``. The flattened
    convolution output must contain exactly ``flat_features`` elements, so the
    spatial size of the input is tied to the conv settings. With the defaults
    (kernel 2, stride 4, 12 conv channels) an input of shape ``(N, 32, 2, 1000)``
    (or unbatched ``(32, 2, 1000)``) gives a ``12 x 1 x 250 = 3000`` feature map.

    NOTE(review): the validation batches printed at the end of this notebook
    have features of shape ``(batch, 2, 1000)``, which a ``Conv2d`` with
    ``in_channels=32`` cannot consume -- TODO reconcile encoder and data layout.

    Args:
        in_channels: number of channels of the input (dimension -3).
        conv_channels: output channels of the convolution.
        flat_features: number of features after flattening the conv output.
        latent_dim: size of the returned latent vector.
        dropout: dropout probability applied after the convolution.
    """

    def __init__(self, in_channels=32, conv_channels=12, flat_features=3000,
                 latent_dim=128, dropout=0.5):
        super().__init__()
        self.conv2D_layers = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=conv_channels,
                      kernel_size=2, stride=4),
            nn.ReLU(),
            nn.Dropout(dropout),
            # start_dim=-3 flattens (C, H, W) for both batched and unbatched inputs.
            nn.Flatten(start_dim=-3),
        )
        self.linear_layers = nn.Sequential(
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim),
        )

    def forward(self, x):
        """Encode ``x`` of shape ``(N, C, H, W)`` or ``(C, H, W)`` to ``(..., latent_dim)``."""
        features = self.conv2D_layers(x)
        return self.linear_layers(features)

class Decoder(nn.Module):
    """MLP decoder mapping a latent vector back to a 1-D signal.

    NOTE: this class shadows the ``Decoder`` imported from
    ``transphorm.model_components.model_modules``; once this cell has run,
    the name ``Decoder`` refers to this definition.

    Args:
        latent_dim: size of the incoming latent vector.
        output_dim: length of the reconstructed signal.
    """

    def __init__(self, latent_dim=128, output_dim=1000):
        super().__init__()
        self.linear_layers = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            # Without a non-linearity here the two adjacent Linear layers would
            # compose into a single affine map, making one of them redundant.
            nn.ReLU(),
            nn.Linear(128, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
        )

    def forward(self, x):
        """Decode ``x`` of shape ``(..., latent_dim)`` to ``(..., output_dim)``."""
        return self.linear_layers(x)
    
class AutoEncoder2D(L.LightningModule):
    """Lightning module training an encoder/decoder pair to reconstruct a signal.

    Args:
        encoder: ``nn.Module`` subclass (not an instance); instantiated with no arguments.
        decoder: ``nn.Module`` subclass (not an instance); instantiated with no arguments.
        optimizer: optimizer factory called as ``optimizer(params)``,
            e.g. ``torch.optim.Adam``.
    """

    def __init__(self, encoder, decoder, optimizer):
        super().__init__()
        self.encoder = encoder()
        self.decoder = decoder()
        self.optimizer = optimizer
        self.loss_fn = nn.MSELoss()
        # encoder/decoder are passed as classes rather than plain config values,
        # so keep them out of the saved hyper-parameters.
        self.save_hyperparameters(ignore=["encoder", "decoder"])

    def forward(self, inputs):
        """Encode then decode ``inputs``; returns the reconstruction."""
        latent = self.encoder(inputs)
        return self.decoder(latent)

    def configure_optimizers(self):
        """Build the optimizer over all module parameters."""
        return self.optimizer(self.parameters())

    def _common_step(self, batch, batch_idx):
        """Return the MSE reconstruction loss for one batch.

        ``batch[0]`` holds the input features. The target is channel 0 of every
        sample. NOTE(review): assumes features are ``(batch, channels, length)``,
        as in the validation batch printed at the end of the notebook -- confirm.
        """
        features = batch[0]
        # Index the channel axis (dim 1), not the batch axis: ``features[0]``
        # would select only the first sample, which cannot line up with the
        # per-sample decoder output.
        signal_truth = features[:, 0]
        x_hat = self(features)
        return self.loss_fn(x_hat, signal_truth)

    def training_step(self, batch, batch_idx):
        """Compute and log the training reconstruction loss."""
        loss = self._common_step(batch, batch_idx)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation reconstruction loss."""
        loss = self._common_step(batch, batch_idx)
        self.log("val_loss", loss, prog_bar=True)
        return loss
    

In [4]:
# Run configuration.
SEED = 42
BATCH_SIZE = 10
MAX_EPOCHS = 10
# Cap DataLoader workers at the machine's core count (presumably forwarded to
# torch's DataLoader). The previous value of 1000 far exceeds any workstation;
# the captured run below stalled at "Sanity Checking 0/?", possibly because of it.
NUM_WORKERS = min(4, os.cpu_count() or 1)

# Seed Python/NumPy/torch RNGs so weight init and dropout are reproducible.
L.seed_everything(SEED)

data_module = SyntheticFPDataModule(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
# trainer.fit() also invokes prepare_data()/setup(); the explicit calls make
# the data available for inspection before training.
# NOTE(review): Lightning's stage names are "fit"/"validate"/"test"/"predict";
# confirm SyntheticFPDataModule.setup() handles "train".
data_module.prepare_data()
data_module.setup("train")

auto_encoder = AutoEncoder2D(
    encoder=CnnEncoder,
    decoder=Decoder,
    optimizer=torch.optim.Adam,
)
trainer = L.Trainer(max_epochs=MAX_EPOCHS)
trainer.fit(auto_encoder, data_module)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/mds8301/anaconda3/envs/transphorm/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name    | Type       | Params
---------------------------------------
0 | encoder | CnnEncoder | 1.7 M 
1 | decoder | Decoder    | 612 K 
2 | loss_fn | MSELoss    | 0     
---------------------------------------
2.3 M     Trainable params
0         Non-trainable params
2.3 M     Total para

Sanity Checking: |          | 0/? [00:00<?, ?it/s]



In [14]:
# Inspect the structure of a single validation batch rather than iterating the
# whole loader. Each batch holds two tensors: the input features and a 1-D
# per-sample tensor (labels, presumably -- confirm against the data module).
val_loader = data_module.val_dataloader()
val_features, val_labels = next(iter(val_loader))
val_features.shape, val_labels.shape

torch.Size([20, 2, 1000])
torch.Size([20])
