In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Define MDD Classifier
class MDDClassifier(nn.Module):
    def __init__(self):
        super(MDDClassifier, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=19, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool1d(kernel_size=2)
        self.batch_norm1 = nn.BatchNorm1d(16)
        self.batch_norm2 = nn.BatchNorm1d(32)

        encoder_layers = nn.TransformerEncoderLayer(d_model=32, nhead=4)
        self.transformer = nn.TransformerEncoder(encoder_layers, num_layers=2)

        self.fc = nn.Linear(32 + 1, 1)  # Concatenate with time embedding

    def forward(self, x, t):
        x = self.pool(nn.ReLU()(self.batch_norm1(self.conv1(x))))
        x = self.pool(nn.ReLU()(self.batch_norm2(self.conv2(x))))
        x = x.permute(2, 0, 1)  # Reshape for transformer [seq_len, batch_size, feature_dim]
        x = self.transformer(x)
        x = x.mean(dim=0)  # Global average pooling
        x = torch.cat([x, t], dim=1)
        x = self.fc(x)
        return x

# Define Time Embedding
class TimeEmbedding(nn.Module):
    def __init__(self, embed_size, max_time=1000):
        super(TimeEmbedding, self).__init__()
        self.time_embed = nn.Embedding(max_time, embed_size)

    def forward(self, t):
        return self.time_embed((t % self.time_embed.num_embeddings).long())

# Define Hemispherical Embedding
class HemisphericalEmbedding(nn.Module):
    def __init__(self, embed_size):
        super(HemisphericalEmbedding, self).__init__()
        self.embed = nn.Embedding(2, embed_size)  # Assuming 2 hemispheres

    def forward(self, hemisphere):
        return self.embed(hemisphere.long())  # Convert to LongTensor

# Define Functional Region Embedding
class FunctionalRegionEmbedding(nn.Module):
    def __init__(self, embed_size):
        super(FunctionalRegionEmbedding, self).__init__()
        self.embed = nn.Embedding(1000, embed_size)  # Assuming 10 functional regions

    def forward(self, region):
        return self.embed(region.long())  # Convert to LongTensor

# Define Residual Layer
class ResidualLayer(nn.Module):
    def __init__(self, input_dim):
        super(ResidualLayer, self).__init__()
        self.conv1 = nn.Conv1d(input_dim, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(16, input_dim, kernel_size=3, padding=1)
        self.bidi_conv = nn.Conv1d(input_dim, input_dim, kernel_size=3, padding=1)  # Bi-directional Conv
        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()

    def forward(self, x, t_embed):
        res = x
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = x + res  # Skip connection
        x = self.bidi_conv(x)
        x = self.tanh(x)
        return x

# Define Diffusion Model
class DiffusionModel(nn.Module):
    def __init__(self):
        super(DiffusionModel, self).__init__()
        self.residual1 = ResidualLayer(19)  # Updated for multi-channel input
        self.residual2 = ResidualLayer(19)  # Updated for multi-channel input
        self.time_embed = TimeEmbedding(19)
        self.hemi_embed = HemisphericalEmbedding(19)
        self.region_embed = FunctionalRegionEmbedding(19)

    def forward(self, x, t, hemisphere, region):
        t_embed = self.time_embed(t)
        hemi_embed = self.hemi_embed(hemisphere)
        region_embed = self.region_embed(region)

        x = self.residual1(x + t_embed)
        x = self.residual2(x + t_embed + hemi_embed + region_embed)  # ADD & norm part
        return x

# Generate noisy signals using a Markov process and Gaussian noise
def add_noise(x, beta_start=0.1, beta_end=0.2, T=10):
    betas = torch.linspace(beta_start, beta_end, T)
    alphas = 1 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    noise = torch.randn_like(x)
    x_noisy = torch.zeros_like(x)
    for t in range(T):
        x_noisy = torch.sqrt(alphas_cumprod[t]) * x + torch.sqrt(1 - alphas_cumprod[t]) * noise
    return x_noisy

# Example function to generate augmented data
def generate_augmented_data(diffusion_model, x, t, hemisphere, region):
    return diffusion_model(x, t, hemisphere, region)

# Example training loop for each time step t
def train_classifier_for_t(classifier, train_loader, t, epochs=1):
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(classifier.parameters(), lr=0.001)
    classifier.train()

    for epoch in range(epochs):
        for x, y, _, _ in train_loader:
            x = add_noise(x)  # Add noise to the input
            t_tensor = torch.tensor([t] * x.size(0)).unsqueeze(1).float()
            optimizer.zero_grad()
            output = classifier(x, t_tensor)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()

# Data Preparation
num_samples = 10
num_channels = 19
num_timepoints = 1280

# EEG Signal (random example data)
x_train = torch.randn(num_samples, num_channels, num_timepoints)  # Shape: (10, 19, 1280)

# Labels (random example data)
y_train = torch.randint(0, 2, (num_samples, 1)).float()  # Shape: (10, 1)

# Hemispheres (example data)
hemispheres = torch.randint(0, 2, (num_samples, num_channels))  # Ensure values are 0 or 1, Shape: (10, 19)

# Functional Regions (random example data)
regions = torch.randint(0, 10, (num_samples, num_channels))  # Shape: (10, 19)

# Create a DataLoader for your dataset
train_dataset = TensorDataset(x_train, y_train, hemispheres, regions)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

# Initialize models
classifier = MDDClassifier()
diffusion_model = DiffusionModel()

# Training loop for T time steps
for t in range(10):
    train_classifier_for_t(classifier, train_loader, t)

# Generate augmented data using the diffusion model
for x, y, hemisphere, region in train_loader:
    t = torch.tensor([500] * x.size(0)).unsqueeze(1).float()  # Example time step
    augmented_data = generate_augmented_data(diffusion_model, x, t, hemisphere, region)
    print(f"Augmented Data for time step {t.item()}: {augmented_data}")

# Final output (y') from the MDD classifier
for x, y, hemisphere, region in train_loader:
    t = torch.tensor([500] * x.size(0)).unsqueeze(1).float()  # Example time step
    classifier_output = classifier(x, t)
    y_prime = torch.sigmoid(classifier_output)
    print(f"Classifier output (y') for time step {t.item()}: {y_prime}")


IndexError: index out of range in self