In [17]:
# Cell 1: Imports and Data Loading
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from ast import literal_eval
import matplotlib.pyplot as plt
%matplotlib inline

# Load and preprocess data
data = pd.read_csv('adfa_ld_processed.csv')
data['sequence'] = data['sequence'].apply(literal_eval)

# Find max length and pad sequences
max_len = max(len(seq) for seq in data['sequence'])
def pad_sequence(seq):
    return np.pad([int(x) for x in seq], 
                 (0, max_len - len(seq)), 
                 'constant')

X = np.array([pad_sequence(seq) for seq in data['sequence']])
y = pd.get_dummies(data['label']).values

# Normalize to [0,1]
X = (X - X.min()) / (X.max() - X.min())

# Convert to tensors
X_tensor = torch.FloatTensor(X).unsqueeze(1)
y_tensor = torch.FloatTensor(y)

# Create DataLoader
dataset = TensorDataset(X_tensor, y_tensor)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
print(f"Data loaded with shape: {X_tensor.shape}")

Data loaded with shape: torch.Size([1579, 1, 2948])


In [18]:
# Cell 2: LSTM-GAN Architecture
import torch.nn as nn

class LSTMGenerator(nn.Module):
    def __init__(self, latent_dim, sequence_length, num_classes=2, hidden_dim=512):
        super().__init__()
        
        self.lstm = nn.LSTM(
            input_size=latent_dim + 32,
            hidden_size=hidden_dim,
            num_layers=3,
            batch_first=True,
            dropout=0.3
        )
        
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=8,
            dropout=0.1
        )
        
        self.label_embedding = nn.Embedding(num_classes, 32)
        
        self.output_layer = nn.Sequential(
            nn.Linear(hidden_dim, sequence_length),
            nn.Sigmoid()
        )
        
    def forward(self, z, labels, seq_len=100):
        batch_size = z.size(0)
        label_embed = self.label_embedding(labels)
        z_sequence = z.unsqueeze(1).repeat(1, seq_len, 1)
        label_sequence = label_embed.unsqueeze(1).repeat(1, seq_len, 1)
        lstm_input = torch.cat([z_sequence, label_sequence], dim=-1)
        
        lstm_out, _ = self.lstm(lstm_input)
        attn_out, _ = self.attention(lstm_out, lstm_out, lstm_out)
        
        return self.output_layer(attn_out.mean(dim=1))

# Cell 3: Complete Architecture and Setup
class LSTMDiscriminator(nn.Module):
    def __init__(self, sequence_length, num_classes=2, hidden_dim=512):
        super().__init__()


        
        self.lstm = nn.LSTM(
            input_size=sequence_length + 32,
            hidden_size=hidden_dim,
            num_layers=3,
            batch_first=True,
            dropout=0.3
        )
        self.batch_norm = nn.BatchNorm1d(hidden_dim)

        
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=8,
            dropout=0.1
        )
        
        self.label_embedding = nn.Embedding(num_classes, 32)
        
        self.output_layer = nn.Sequential(
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x, labels):
        label_embed = self.label_embedding(labels)
        label_sequence = label_embed.unsqueeze(1).repeat(1, x.size(1), 1)
        lstm_input = torch.cat([x, label_sequence], dim=-1)
        
        lstm_out, _ = self.lstm(lstm_input)
        attn_out, _ = self.attention(lstm_out, lstm_out, lstm_out)
        
        return self.output_layer(attn_out.mean(dim=1))

# Configuration
config = {
    'n_epochs': 200,
    'batch_size': 64,
    'lr_g': 0.0002,
    'lr_d': 0.0005,
    'beta1': 0.5,
    'beta2': 0.999,
    'latent_dim': 128,
    'sequence_length': X_tensor.shape[2],
    'num_classes': y_tensor.shape[1],
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'gp_lambda': 5.0,
    'n_critic': 2
}

# Initialize models
generator = LSTMGenerator(
    config['latent_dim'], 
    config['sequence_length']
).to(config['device'])

discriminator = LSTMDiscriminator(
    config['sequence_length']
).to(config['device'])

# Setup optimizers
g_optimizer = Adam(generator.parameters(), lr=config['lr_g'], betas=(config['beta1'], config['beta2']))
d_optimizer = Adam(discriminator.parameters(), lr=config['lr_d'], betas=(config['beta1'], config['beta2']))



# Add schedulers
g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma=0.995)
d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma=0.995)

# Loss function
adversarial_loss = nn.BCELoss()

print(f"Models initialized on: {config['device']}")

print("Models initialized successfully!")

Models initialized on: cpu
Models initialized successfully!


In [19]:
# Gradient Penalty
def compute_gradient_penalty(discriminator, real_samples, fake_samples, labels):
    alpha = torch.rand(real_samples.size(0), 1, 1).to(config['device'])
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = discriminator(interpolates, labels)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

# Label smoothing
def smooth_labels(size):
    return torch.ones(size).uniform_(0.8, 1.0).to(config['device'])


In [20]:
# Cell 4: Fixed Training Loop Implementation
# Cell 4: Fixed Training Loop Implementation
import os
from datetime import datetime
import torch
import torch.nn as nn
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Training Loop
def train_lstm_gan():
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    save_dir = f'lstm_gan_checkpoints_{timestamp}'
    os.makedirs(save_dir, exist_ok=True)
    
    d_losses, g_losses = [], []
    
    try:
        for epoch in range(config['n_epochs']):
            d_epoch_loss, g_epoch_loss = 0, 0
            pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{config["n_epochs"]}')
            
            # Calculate noise factor for instance noise
            noise_factor = max(0.1 * (1.0 - epoch/config['n_epochs']), 0)
            
            for batch_idx, (real_data, labels) in enumerate(pbar):
                try:
                    batch_size = real_data.size(0)
                    
                    # Reshape and add instance noise
                    real_data = real_data.view(batch_size, -1, config['sequence_length']).to(config['device'])
                    real_data += noise_factor * torch.randn_like(real_data)
                    labels = labels.to(config['device'])
                    
                    # Smooth labels
                    valid = smooth_labels((batch_size, 1))
                    fake = torch.zeros(batch_size, 1).to(config['device'])
                    
                    # Train Generator
                    if batch_idx % config['n_critic'] == 0:
                        g_optimizer.zero_grad()
                        z = torch.randn(batch_size, config['latent_dim']).to(config['device'])
                        generated_data = generator(z, labels.argmax(1))
                        generated_data = generated_data.unsqueeze(1)
                        
                        validity = discriminator(generated_data, labels.argmax(1))
                        g_loss = adversarial_loss(validity, valid)
                        
                        g_loss.backward()
                        torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
                        g_optimizer.step()
                        
                        g_epoch_loss += g_loss.item()
                    
                    # Train Discriminator
                    d_optimizer.zero_grad()
                    
                    # Real data
                    real_validity = discriminator(real_data, labels.argmax(1))
                    real_loss = adversarial_loss(real_validity, valid)
                    
                    # Fake data
                    z = torch.randn(batch_size, config['latent_dim']).to(config['device'])
                    fake_data = generator(z, labels.argmax(1)).unsqueeze(1).detach()
                    fake_validity = discriminator(fake_data, labels.argmax(1))
                    fake_loss = adversarial_loss(fake_validity, fake)
                    
                    # Gradient penalty
                    gp = compute_gradient_penalty(
                        discriminator, real_data, fake_data, labels.argmax(1)
                    )
                    
                    # Total discriminator loss
                    d_loss = (real_loss + fake_loss) / 2 + (config['gp_lambda'] * gp / batch_size)
                    
                    d_loss.backward()
                    torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
                    d_optimizer.step()
                    
                    d_epoch_loss += d_loss.item()
                    
                    # Update progress bar
                    pbar.set_postfix({
                        'D_loss': f'{d_loss.item():.4f}',
                        'G_loss': f'{g_loss.item():.4f}' if 'g_loss' in locals() else 'N/A'
                    })
                    
                except RuntimeError as e:
                    print(f"Batch error: {e}")
                    continue
            
            # Calculate average losses
            n_batches = len(dataloader)
            avg_d_loss = d_epoch_loss / n_batches
            avg_g_loss = g_epoch_loss / (n_batches // config['n_critic'])
            d_losses.append(avg_d_loss)
            g_losses.append(avg_g_loss)
            
            print(f"\nEpoch [{epoch+1}/{config['n_epochs']}] "
                  f"D_loss: {avg_d_loss:.4f} G_loss: {avg_g_loss:.4f}")
            
            # Step schedulers
            g_scheduler.step()
            d_scheduler.step()

            # Save checkpoint
            if (epoch + 1) % 50 == 0:
                checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch+1}.pt')
                torch.save({
                    'epoch': epoch,
                    'generator_state_dict': generator.state_dict(),
                    'discriminator_state_dict': discriminator.state_dict(),
                    'g_optimizer_state_dict': g_optimizer.state_dict(),
                    'd_optimizer_state_dict': d_optimizer.state_dict(),
                    'g_loss': avg_g_loss,
                    'd_loss': avg_d_loss
                }, checkpoint_path)
                
    except Exception as e:
        print(f"Training error: {e}")
        
    return d_losses, g_losses

# Start training
d_losses, g_losses = train_lstm_gan()

# Plot results
plt.figure(figsize=(10,5))
plt.plot(d_losses, label='Discriminator Loss')
plt.plot(g_losses, label='Generator Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('LSTM-GAN Training Progress')
plt.legend()
plt.grid(True)
plt.show()

Epoch 1/200:   0%|          | 0/50 [00:00<?, ?it/s]


Epoch [1/200] D_loss: 0.4116 G_loss: 9.0978


Epoch 2/200:   0%|          | 0/50 [00:00<?, ?it/s]


Epoch [2/200] D_loss: 0.3360 G_loss: 10.7018


Epoch 3/200:   0%|          | 0/50 [00:00<?, ?it/s]

KeyboardInterrupt: 