In [11]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torch.nn.functional as F

In [12]:
class Encoder(nn.Module):
    """
        Encoder Class:
        - This class is responsible for encoding input data into a compressed representation using a combination of convolutional layers, pooling layers, a BiLSTM layer, and self-attention.
        
        Input:
            - x: Tensor of shape (B, 1, 100, 7), where B is the batch size, 1 is the number of input channels, and (100, 7) is the spatial dimension of the input.
        Output:
            - x: Tensor of shape (B, 64, 5), where 64 is the feature dimension, and 5 is the reduced spatial dimension after encoding.
    """
    def __init__(self):
        super().__init__()

        self.bn1 = nn.BatchNorm2d(1)  # Input shape: (B, 1, 100, 7)
        self.conv1 = nn.Conv2d(1, 16, kernel_size=(3, 3), padding=1)  # Shape: (B, 16, 100, 7)
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 1))  # Shape: (B, 16, 50, 7)
        
        self.conv2 = nn.Conv2d(16, 32, kernel_size=(3, 3), padding=1)  # Shape: (B, 32, 50, 7)
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 1))  # Shape: (B, 32, 25, 7)
        
        self.conv3 = nn.Conv2d(32, 64, kernel_size=(3, 3), padding=1)  # Shape: (B, 64, 25, 7)
        self.pool3 = nn.MaxPool2d(kernel_size=(5, 7))  # Shape: (B, 64, 5, 1)
        
        self.reshape = nn.Flatten(2)  # Shape: (B, 64, 5)
        self.bilstm = nn.LSTM(64, 32, batch_first=True, bidirectional=True)  # BiLSTM
        self.self_attention = nn.MultiheadAttention(embed_dim=64, num_heads=2)
    
    def forward(self, x):
        x = self.bn1(x)
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        
        x = F.relu(self.conv3(x))
        x = self.pool3(x)  # Shape: (B, 64, 5, 1)
        
        x = self.reshape(x)  # Shape: (B, 64, 5)
        x = x.transpose(1, 2)  # Change to (B, 5, 64)
        x, _ = self.bilstm(x)  # BiLSTM output shape: (B, 5, 64)
        x = x.transpose(1, 2)  # Back to (B, 64, 5)
        
        x = x.permute(2, 0, 1)  # Self-attention requires shape (T, B, C)
        x, _ = self.self_attention(x, x, x)
        x = x.permute(1, 2, 0)  # Shape back to (B, 64, 5)
        return x

class Decoder(nn.Module):
    """
        Decoder Class:
        - This class reconstructs input data from a compressed representation. It uses upsampling layers and convolutional layers to expand the spatial dimensions back to the original shape.

        Input:
            - x: Tensor of shape (B, 64, 5, 1), where B is the batch size, 64 is the feature dimension, and (5, 1) is the spatial dimension of the encoded input.
        Output:
            - x: Tensor of shape (B, 1, 100, 7), where 1 is the output channel, and (100, 7) is the reconstructed spatial dimension.
    """
    def __init__(self):

        super().__init__()
        # Decoder blocks: Reverse of the Encoder
        self.up1 = nn.Upsample(size=(25, 7), mode='bilinear')
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        
        self.up2 = nn.Upsample(size=(50, 7), mode='bilinear')
        self.conv5 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        
        self.up3 = nn.Upsample(size=(100, 7), mode='bilinear')
        self.conv6 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
        
        self.conv7 = nn.Conv2d(16, 1, kernel_size=3, padding=1)  # Final reconstruction layer
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        x = self.up1(x)  # Upsample to (25, 7)
        x = F.relu(self.conv4(x))  # (B, 64, 25, 7)
        
        x = self.up2(x)  # Upsample to (50, 7)
        x = F.relu(self.conv5(x))  # (B, 32, 50, 7)
        
        x = self.up3(x)  # Upsample to (100, 7)
        x = F.relu(self.conv6(x))  # (B, 16, 100, 7)
        
        x = self.conv7(x)  # Reconstruct to (B, 1, 100, 7)

        return self.sigmoid(x)  # Ensure output is in range [0, 1]

class encoder_decoder_model(pl.LightningModule):
    """
        Encoder-Decoder with Classifier:
        - Combines two encoder-decoder branches (energy and peak) for reconstruction, and a classification branch for supervised learning.

        Input:
            - energy_input: Tensor of shape (B, 1, 100, 7).
            - peak_input: Tensor of shape (B, 1, 100, 7).
        Output:
            - energy_reconstructed: Tensor of shape (B, 1, 100, 7).
            - peak_reconstructed: Tensor of shape (B, 1, 100, 7).
            - classification_output: Tensor of shape (B, num_classes), where `num_classes` is the number of target classes.
    """
    def __init__(self, num_classes=4):
        super().__init__()
        # Two encoder-decoder branches
        self.energy_encoder = Encoder()
        self.energy_decoder = Decoder()
        
        self.peak_encoder = Encoder()
        self.peak_decoder = Decoder()
        
        # Classifier layers
        self.dropout = nn.Dropout(0.2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(640, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.fc2 = nn.Linear(128, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, energy_input, peak_input):
        # Encoder-decoder for reconstruction
        energy_encoded = self.energy_encoder(energy_input)
        energy_reconstructed = self.energy_decoder(energy_encoded.unsqueeze(-1))
        
        peak_encoded = self.peak_encoder(peak_input)
        peak_reconstructed = self.peak_decoder(peak_encoded.unsqueeze(-1))
        
        # Classification branch
        x_concat = torch.cat((energy_encoded, peak_encoded), dim=2)  # Shape: (B, 64, 10)
        x_concat = x_concat.permute(0, 2, 1)  # Shape: (B, 10, 64)
        
        x_dropout = self.dropout(x_concat)
        x_flatten = self.flatten(x_dropout)  # Shape: (B, 640)
        
        x_fc1 = F.relu(self.fc1(x_flatten))
        x_bn2 = self.bn2(x_fc1)
        x_fc2 = self.fc2(x_bn2)
        classification_output = self.softmax(x_fc2)
        
        return energy_reconstructed, peak_reconstructed, classification_output
    def training_step(self, batch, batch_idx):
        energy_input, peak_input, y = batch
        
        # Forward pass
        energy_recon, peak_recon, logits = self(energy_input, peak_input)
        
        # Reconstruction loss (MSE) + Classification loss (CrossEntropy)
        recon_loss = F.mse_loss(energy_recon, energy_input) + F.mse_loss(peak_recon, peak_input)
        class_loss = F.cross_entropy(logits, y)
        
        loss = recon_loss + class_loss
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=1e-5)
        return optimizer

class EncoderDecoderPretrain(pl.LightningModule):
    """
        Encoder-Decoder Pretraining Class:
        - Pretrains the energy and peak encoder-decoder branches using reconstruction loss (MSE loss) on unlabeled data.

        Input:
            - energy_input: Tensor of shape (B, 1, 100, 7).
            - peak_input: Tensor of shape (B, 1, 100, 7).
        Output:
            - energy_reconstructed: Tensor of shape (B, 1, 100, 7).
            - peak_reconstructed: Tensor of shape (B, 1, 100, 7).
    """
    def __init__(self, energy_encoder, energy_decoder, peak_encoder, peak_decoder, lr=1e-3):
        super().__init__()
        self.energy_encoder = energy_encoder
        self.energy_decoder = energy_decoder
        self.peak_encoder = peak_encoder
        self.peak_decoder = peak_decoder
        self.lr = lr

    def forward(self, energy_input, peak_input):
        energy_encoded = self.energy_encoder(energy_input)
        energy_reconstructed = self.energy_decoder(energy_encoded.unsqueeze(-1))
        
        peak_encoded = self.peak_encoder(peak_input)
        peak_reconstructed = self.peak_decoder(peak_encoded.unsqueeze(-1))
        
        return energy_reconstructed, peak_reconstructed

    def training_step(self, batch, batch_idx):
        energy_input, peak_input = batch
        energy_recon, peak_recon = self(energy_input, peak_input)
        
        recon_loss = F.mse_loss(energy_recon, energy_input) + F.mse_loss(peak_recon, peak_input)
        self.log("pretrain_recon_loss", recon_loss)
        return recon_loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            list(self.energy_encoder.parameters()) +
            list(self.energy_decoder.parameters()) +
            list(self.peak_encoder.parameters()) +
            list(self.peak_decoder.parameters()),
            lr=self.lr
        )
        return optimizer
class ClassifierTrain(pl.LightningModule):
    """
        Classifier Training Class:
        - Fine-tunes a classification head using pre-trained encoder branches. Encoders are frozen during this phase.

        Input:
            - energy_input: Tensor of shape (B, 1, 100, 7).
            - peak_input: Tensor of shape (B, 1, 100, 7).
        Output:
            - logits: Tensor of shape (B, num_classes), where `num_classes` is the number of target classes.
    """
    def __init__(self, energy_encoder, peak_encoder, fc1, fc2, bn2, lr=1e-3):
        super().__init__()
        self.energy_encoder = energy_encoder
        self.peak_encoder = peak_encoder
        self.fc1 = fc1
        self.fc2 = fc2
        self.bn2 = bn2
        self.lr = lr

        for param in self.energy_encoder.parameters():
            param.requires_grad = False
        for param in self.peak_encoder.parameters():
            param.requires_grad = False

    def forward(self, energy_input, peak_input):
        with torch.no_grad():
            energy_encoded = self.energy_encoder(energy_input)
            peak_encoded = self.peak_encoder(peak_input)
        
        x_concat = torch.cat((energy_encoded, peak_encoded), dim=2)
        x_concat = x_concat.permute(0, 2, 1)
        x_flatten = x_concat.flatten(start_dim=1)

        x_fc1 = F.relu(self.fc1(x_flatten))
        x_bn2 = self.bn2(x_fc1)
        logits = self.fc2(x_bn2)
        return logits

    def training_step(self, batch, batch_idx):
        energy_input, peak_input, labels = batch
        logits = self(energy_input, peak_input)
        class_loss = F.cross_entropy(logits, labels)
        self.log("class_loss", class_loss)
        return class_loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            list(self.fc1.parameters()) +
            list(self.fc2.parameters()) +
            list(self.bn2.parameters()),
            lr=self.lr
        )
        return optimizer
def train_ssae_lightning(model, unlabeled_loader, labeled_loader, num_pretrain_epochs, num_classifier_epochs):
    """
        Training Workflow Function:
        - This function handles the two-phase training process:
            1. Pretraining the encoder-decoder branches using reconstruction loss.
            2. Fine-tuning the classification head using supervised data.

        Parameters:
            - model: Instance of the `encoder_decoder_model`.
            - unlabeled_loader: DataLoader for unlabeled data (used for pretraining).
            - labeled_loader: DataLoader for labeled data (used for classification).
            - num_pretrain_epochs: Number of epochs for pretraining.
            - num_classifier_epochs: Number of epochs for classifier fine-tuning.
    """
    # Pretraining
    pretrain_model = EncoderDecoderPretrain(
        energy_encoder=model.energy_encoder,
        energy_decoder=model.energy_decoder,
        peak_encoder=model.peak_encoder,
        peak_decoder=model.peak_decoder
    )
    pretrain_trainer = pl.Trainer(max_epochs=num_pretrain_epochs)
    pretrain_trainer.fit(pretrain_model, unlabeled_loader)
    
    # Classifier Training
    classifier_model = ClassifierTrain(
        energy_encoder=model.energy_encoder,
        peak_encoder=model.peak_encoder,
        fc1=model.fc1,
        fc2=model.fc2,
        bn2=model.bn2
    )
    classifier_trainer = pl.Trainer(max_epochs=num_classifier_epochs)
    classifier_trainer.fit(classifier_model, labeled_loader)


In [14]:
import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
import torch.nn.functional as F

# Dummy Dataset for Testing
class DummyDataset(Dataset):
    def __init__(self, num_samples=16, seq_len=100, channels=7, num_classes=4, labeled=True):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.channels = channels
        self.num_classes = num_classes
        self.labeled = labeled
    
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        energy_data = torch.rand(1, self.seq_len, self.channels)  # Shape: (1, 100, 7)
        peak_data = torch.rand(1, self.seq_len, self.channels)    # Shape: (1, 100, 7)
        
        if self.labeled:
            label = torch.randint(0, self.num_classes, (1,)).item()
            return energy_data, peak_data, label
        else:
            return energy_data, peak_data

# SSAE Model Testing Pipeline
def test_ssae_pipeline_lightning():
    print("Initializing SSAE Model...")
    # Replace with your main model initialization
    model = encoder_decoder_model(num_classes=4)
    
    # Hyperparameters
    batch_size = 4
    num_unlabeled_samples = 100  # Unlabeled dataset for pretraining
    num_labeled_samples = 40     # Labeled dataset for classifier training
    num_pretrain_epochs = 5
    num_classifier_epochs = 5
    num_test_samples = 16        # Test dataset size
    
    # Data Loaders
    print("Creating Dummy Unlabeled Dataset...")
    unlabeled_dataset = DummyDataset(num_samples=num_unlabeled_samples, labeled=False)
    unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=batch_size, shuffle=True)
    
    print("Creating Dummy Labeled Dataset...")
    labeled_dataset = DummyDataset(num_samples=num_labeled_samples, labeled=True)
    labeled_loader = DataLoader(labeled_dataset, batch_size=batch_size, shuffle=True)
    
    print("Creating Dummy Test Dataset...")
    test_dataset = DummyDataset(num_samples=num_test_samples, labeled=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    # Step 1: Pre-train Encoder-Decoder on Unlabeled Data
    print("Step 1: Pre-training Encoder-Decoder on Unlabeled Data...")
    pretrain_model = EncoderDecoderPretrain(
        energy_encoder=model.energy_encoder,
        energy_decoder=model.energy_decoder,
        peak_encoder=model.peak_encoder,
        peak_decoder=model.peak_decoder
    )
    pretrain_trainer = pl.Trainer(max_epochs=num_pretrain_epochs)
    pretrain_trainer.fit(pretrain_model, unlabeled_loader)
    print("Pre-training Completed!\n")
    
    # Step 2: Train Classifier on Labeled Data
    print("Step 2: Training Classifier on Labeled Data...")
    classifier_model = ClassifierTrain(
        energy_encoder=model.energy_encoder,
        peak_encoder=model.peak_encoder,
        fc1=model.fc1,
        fc2=model.fc2,
        bn2=model.bn2
    )
    classifier_trainer = pl.Trainer(max_epochs=num_classifier_epochs)
    classifier_trainer.fit(classifier_model, labeled_loader)
    print("Classifier Training Completed!\n")
    
    # Step 3: Test the Model on Test Data
    print("Step 3: Testing the Model on Test Data...")
    classifier_model.eval()  # Set model to evaluation mode
    correct = 0
    total = 0
    
    with torch.no_grad():
        for energy_input, peak_input, labels in test_loader:
            logits = classifier_model(energy_input, peak_input)
            predicted = torch.argmax(logits, dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    
    accuracy = correct / total * 100
    print(f"Test Accuracy: {accuracy:.2f}%")
    print("Testing Completed Successfully!")

if __name__ == "__main__":
    test_ssae_pipeline_lightning()


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type    | Params | Mode 
---------------------------------------------------
0 | energy_encoder | Encoder | 65.0 K | train
1 | energy_decoder | Decoder | 60.2 K | train
2 | peak_encoder   | Encoder | 65.0 K | train
3 | peak_decoder   | Decoder | 60.2 K | train
---------------------------------------------------
250 K     Trainable params
0         Non-trainable params
250 K     Total params
1.001     Total estimated model params size (MB)
42        Modules in train mode
0         Modules in eval mode


Initializing SSAE Model...
Creating Dummy Unlabeled Dataset...
Creating Dummy Labeled Dataset...
Creating Dummy Test Dataset...
Step 1: Pre-training Encoder-Decoder on Unlabeled Data...
Epoch 4: 100%|██████████| 25/25 [00:01<00:00, 16.68it/s, v_num=4]

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 25/25 [00:01<00:00, 16.37it/s, v_num=4]
Pre-training Completed!

Step 2: Training Classifier on Labeled Data...


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type        | Params | Mode 
-------------------------------------------------------
0 | energy_encoder | Encoder     | 65.0 K | train
1 | peak_encoder   | Encoder     | 65.0 K | train
2 | fc1            | Linear      | 82.0 K | train
3 | fc2            | Linear      | 516    | train
4 | bn2            | BatchNorm1d | 256    | train
-------------------------------------------------------
82.8 K    Trainable params
130 K     Non-trainable params
212 K     Total params
0.851     Total estimated model params size (MB)
27        Modules in train mode
0         Modules in eval mode


Epoch 4: 100%|██████████| 10/10 [00:00<00:00, 30.81it/s, v_num=5]

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 10/10 [00:00<00:00, 30.00it/s, v_num=5]
Classifier Training Completed!

Step 3: Testing the Model on Test Data...
Test Accuracy: 12.50%
Testing Completed Successfully!
