In [198]:
import torch
from data import load_traindata
# Pick the compute backend: Apple MPS if available, otherwise CUDA, otherwise CPU.
device = 'mps' if torch.backends.mps.is_available() else ('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset,DataLoader
from torch.utils.data import random_split
from scipy.spatial.distance import directed_hausdorff
from augmentation import augment

mps


In [199]:
# Seed PyTorch's global RNG (used by random_split, weight init, randn_like, shuffling).
# NOTE(review): NumPy / `random` are not seeded here — confirm whether augment() relies on them.
torch.manual_seed(42)

<torch._C.Generator at 0x10c3d3490>

In [200]:
# Hyper-parameters / configuration for the notebook.
num_subclasses = 100  # passed to load_traindata and used as the leading dim when windowing
epochs = 10000        # number of passes over the training DataLoader
seq_size = 250        # window length; recordings are cut into 1000 / seq_size windows
batch_size = 32       # batch size for both DataLoaders
num_aug = 5           # NOTE(review): not referenced in the cells visible here
split_size=0.8        # fraction of windows assigned to the training split
lr = 5e-5             # NOTE(review): not referenced in the cells visible here; the optimizers hard-code their own rates

In [201]:
# Load the recordings and cut each one into consecutive windows of `seq_size`
# samples.  The reshape assumes load_traindata returns an array that can be
# viewed as (num_subclasses, 1000, 12) — TODO confirm against data.py.
X, _ = load_traindata(num_subclasses)
X = torch.tensor(X, dtype=torch.float32)
windows_per_record = 1000 // seq_size
X = X.reshape(int(num_subclasses) * windows_per_record, seq_size, 12)

# Random train / held-out split over windows (random_split returns Subset views).
train_size = int(split_size * len(X))
test_size = len(X) - train_size
X, X_test = random_split(X, [train_size, test_size])

# Training data: augment, then use channel 0 as the condition and channel 1
# as the signal to generate.
X = augment(X)
X_input = X[:, :, 0]  # First channel
Y_target = X[:, :, 1]  # Second channel

# Held-out data.  Index through the Subset itself so that its shuffled
# `indices` are honoured.  (Indexing `X_test.dataset[idx]` with
# idx in range(len(X_test)) would return the first `test_size` rows of the
# *full* tensor, i.e. windows that mostly belong to the training split.)
X_test_tensors = [X_test[idx] for idx in range(len(X_test))]
X_test_tensor = torch.stack(X_test_tensors)
X_t = X_test_tensor[:, :, 0]
Y_t = X_test_tensor[:, :, 1]

(9514, 28)


In [202]:
class Generator(nn.Module):
    """1-D encoder/decoder generator with additive skip connections.

    Six stride-2 ``Conv1d`` stages (kernel 4, padding 1) downsample the input,
    six stride-2 ``ConvTranspose1d`` stages upsample it again.  After each of
    the first five decoder stages the matching encoder activation is added
    (not concatenated); ``match_size`` trims or zero-pads the decoder output
    so the two lengths agree.

    Args:
        in_channels: channels of the input signal (default 2).
        out_channels: channels of the generated signal (default 1).
        num_filters: width of the first encoder stage; deeper stages use
            2x, 4x and 8x this value.
    """

    def __init__(self, in_channels=2, out_channels=1, num_filters=32):
        super().__init__()
        nf = num_filters

        def down(c_in, c_out, normalize=True):
            # Conv (k=4, s=2, p=1) -> optional BatchNorm -> LeakyReLU(0.2)
            layers = [nn.Conv1d(c_in, c_out, kernel_size=4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm1d(c_out))
            layers.append(nn.LeakyReLU(0.2))
            return nn.Sequential(*layers)

        def up(c_in, c_out, dropout=False, output_padding=0):
            # ConvTranspose (k=4, s=2, p=1) -> BatchNorm -> optional Dropout(0.2) -> ReLU
            layers = [
                nn.ConvTranspose1d(c_in, c_out, kernel_size=4, stride=2, padding=1,
                                   output_padding=output_padding),
                nn.BatchNorm1d(c_out),
            ]
            if dropout:
                layers.append(nn.Dropout(0.2))
            layers.append(nn.ReLU())
            return nn.Sequential(*layers)

        # Encoder: only the first stage has no batch norm.
        self.enc1 = down(in_channels, nf, normalize=False)
        self.enc2 = down(nf, nf * 2)
        self.enc3 = down(nf * 2, nf * 4)
        self.enc4 = down(nf * 4, nf * 8)
        self.enc5 = down(nf * 8, nf * 8)
        self.enc6 = down(nf * 8, nf * 8)

        # Decoder: dropout on the two innermost stages only.
        self.dec1 = up(nf * 8, nf * 8, dropout=True)
        self.dec2 = up(nf * 8, nf * 8, dropout=True)
        self.dec3 = up(nf * 8, nf * 4)
        self.dec4 = up(nf * 4, nf * 2, output_padding=1)
        self.dec5 = up(nf * 2, nf)
        self.dec6 = nn.ConvTranspose1d(nf, out_channels, 4, 2, 1)
        # Output non-linearity is a LeakyReLU, so the output is not bounded.
        self.final_activation = nn.LeakyReLU(0.2)

    def match_size(self, x, target):
        """Return ``x`` with its last (third) dimension made equal to ``target``'s.

        Extra trailing samples are cut off; missing ones are zero-padded on
        the right.
        """
        excess = x.size(2) - target.size(2)
        if excess == 0:
            return x
        if excess > 0:
            return x[:, :, :-excess]
        return nn.functional.pad(x, (0, -excess))

    def forward(self, x):
        """Map ``x`` of shape (batch, in_channels, seq_len) to the generated signal."""
        # Encoder pass, remembering every activation for the skip connections.
        skips = []
        h = x
        for stage in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5, self.enc6):
            h = stage(h)
            skips.append(h)

        # Decoder pass: the bottleneck output is consumed first, then each
        # shallower encoder activation is added in reverse order.
        h = skips.pop()
        for stage in (self.dec1, self.dec2, self.dec3, self.dec4, self.dec5):
            skip = skips.pop()
            h = self.match_size(stage(h), skip) + skip

        return self.final_activation(self.dec6(h))

In [203]:
class Discriminator(nn.Module):
    """Convolutional discriminator for 2-channel 1-D signals.

    Six stride-2 ``Conv1d`` layers (kernel 4, padding 1) reduce the sequence
    to a single-channel feature map, which is flattened and mapped to one
    probability by a linear layer followed by a sigmoid.

    Args:
        seq_size: length of the input sequences.  Used to size the final
            linear layer (previously it was ignored and the layer was
            hard-coded to 3 inputs, which only fits lengths 192-255).
        num_filters: width of the first convolution; later layers use
            2x, 4x and 8x this value.

    Raises:
        ValueError: if ``seq_size`` is too short to survive six halvings.
    """

    # Number of stride-2 convolutions applied in ``forward``.
    _NUM_DOWNSAMPLES = 6

    def __init__(self, seq_size, num_filters=32):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=2, out_channels=num_filters, kernel_size=4, stride=2, padding=1)
        self.leakyRelu = nn.LeakyReLU(0.2)
        self.conv2 = nn.Conv1d(num_filters, num_filters*2, 4, 2, 1)
        self.bn1 = nn.BatchNorm1d(num_filters*2)
        self.conv3 = nn.Conv1d(num_filters*2, num_filters*4, 4, 2, 1)
        self.bn2 = nn.BatchNorm1d(num_filters*4)
        self.conv4 = nn.Conv1d(num_filters*4, num_filters*8, 4, 2, 1)
        self.bn3 = nn.BatchNorm1d(num_filters*8)
        self.conv5 = nn.Conv1d(num_filters*8, num_filters*8, 4, 2, 1)
        self.bn4 = nn.BatchNorm1d(num_filters*8)
        self.conv6 = nn.Conv1d(num_filters*8, 1, 4, 2, 1)

        # With kernel 4 / stride 2 / padding 1 each conv maps length L to
        # floor(L / 2), so after six of them the flattened feature count is
        # seq_size // 2**6 (3 for seq_size = 250).
        flat_features = int(seq_size)
        for _ in range(self._NUM_DOWNSAMPLES):
            flat_features //= 2
        if flat_features < 1:
            raise ValueError(
                f"seq_size={seq_size} is too short for {self._NUM_DOWNSAMPLES} stride-2 convolutions"
            )
        self.op = nn.Linear(flat_features, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the probability, shape (batch, 1), for ``x`` of shape (batch, 2, seq_size)."""
        x = self.conv1(x)
        x = self.leakyRelu(x)
        x = self.conv2(x)
        x = self.bn1(x)
        x = self.leakyRelu(x)
        x = self.conv3(x)
        x = self.bn2(x)
        x = self.leakyRelu(x)
        x = self.conv4(x)
        x = self.bn3(x)
        x = self.leakyRelu(x)
        x = self.conv5(x)
        x = self.bn4(x)
        x = self.leakyRelu(x)
        x = self.conv6(x)
        x = self.leakyRelu(x)
        x = x.view(x.size(0), -1)  # Flatten the single remaining channel
        x = self.op(x)
        x = self.sigmoid(x)
        return x

In [204]:
# Paired (condition, target) datasets and their loaders.
dataset = TensorDataset(X_input, Y_target)
test_dataset = TensorDataset(X_t, Y_t)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

# Networks (generator is built first, then the discriminator).
generator = Generator()
generator = generator.to(device)
discriminator = Discriminator(seq_size)
discriminator = discriminator.to(device)

# Separate Adam optimizers; the generator uses a 5x larger step size.
optimizer_g = torch.optim.Adam(params=generator.parameters(), lr=5e-4)
optimizer_d = torch.optim.Adam(params=discriminator.parameters(), lr=1e-4)

# Report model sizes in millions of parameters.
n_generator_params = sum(p.numel() for p in generator.parameters())
n_discriminator_params = sum(p.numel() for p in discriminator.parameters())
print(n_generator_params/1e6, 'M parameters for Generator')
print(n_discriminator_params/1e6, 'M parameters for Discriminator')

1.398145 M parameters for Generator
0.437605 M parameters for Discriminator


In [205]:
def plot_losses(d_losses, g_losses):
    """Plot exponentially smoothed discriminator and generator loss histories.

    Args:
        d_losses: sequence of discriminator loss values.
        g_losses: sequence of generator loss values.
    """
    def ema(values, factor=0.9):
        # Running exponential average; the first value is passed through as is.
        out = []
        running = None
        for value in values:
            if running is None:
                running = value
            else:
                running = running * factor + value * (1 - factor)
            out.append(running)
        return out

    for series, name in ((d_losses, 'D Loss'), (g_losses, 'G Loss')):
        plt.plot(ema(series), label=name)
    plt.legend()
    plt.show()

In [206]:
counter = 0
def plotWave(X, Y, c=0):
    """Plot a generated signal against the real one and print distance metrics.

    Args:
        X: generated signal tensor; a leading singleton dimension is squeezed.
        Y: real signal tensor; a leading singleton dimension is squeezed.
        c: frame number used in the title and printout.

    Returns:
        ``c + 1``, so the caller can keep a running frame counter.
    """
    x_np = X.squeeze(0).detach().cpu().numpy()
    y_np = Y.squeeze(0).detach().cpu().numpy()

    # Plotting
    plt.figure(figsize=(20, 6))
    plt.plot(x_np, color='blue', label='X (Generated Signal)')
    plt.plot(y_np, color='red', label='Y (Real Signal)')
    plt.legend()
    plt.title(f'Wave Comparison - Frame {c}')
    plt.show()

    # MSE
    mse = np.mean((x_np - y_np) ** 2)

    # Symmetric Hausdorff distance between the two sets of sample values.
    # Each sample is treated as a 1-D point, so temporal order is ignored;
    # this is NOT the Fréchet distance, hence the corrected label below.
    x_pts = x_np.reshape(-1, 1)
    y_pts = y_np.reshape(-1, 1)
    hd = max(directed_hausdorff(x_pts, y_pts)[0],
             directed_hausdorff(y_pts, x_pts)[0])

    print(f"Frame {c}:")
    print(f"Mean Squared Error (MSE): {mse:.4f}")
    print(f"Hausdorff Distance: {hd:.4f}")

    return c + 1


In [207]:
def total_variation_loss(x):
    """Mean absolute difference between neighbouring samples along the last axis.

    Differencing along the last axis gives the same result as before for
    2-D ``(batch, seq)`` input and also works for ``(batch, channels, seq)``
    input such as the generator output; slicing dim 1 on a
    ``(batch, 1, seq)`` tensor produced an empty tensor and a NaN mean.
    """
    return torch.mean(torch.abs(x[..., 1:] - x[..., :-1]))

In [208]:
# Adversarial training loop.  BCELoss is applied to the discriminator's
# sigmoid output; one loss value per epoch is recorded for plot_losses.
criterion = nn.BCELoss()
g_losses, d_losses = [], []
for epoch in range(epochs):
    # Each batch pairs channel 0 (real_1, the conditioning signal) with
    # channel 1 (real_2, the signal to be generated) of the same window.
    for real_1, real_2 in dataloader:
        real_1, real_2 = real_1.to(device).float(), real_2.to(device).float()
        real_2 += 0.01 * torch.randn_like(real_2)  # Noise augmentation

        # Train Discriminator
        optimizer_d.zero_grad()

        # Generate noise and condition
        noise = torch.randn_like(real_1)
        gen_input = torch.stack([real_1, noise], dim=1)  # (batch, 2, seq_len)
        fake_2 = generator(gen_input)

        # Real and fake inputs for discriminator
        # Channel order is (candidate signal, condition) in both cases; the fake
        # is detached so this step does not back-propagate into the generator.
        d_real_input = torch.cat([real_2.unsqueeze(1), real_1.unsqueeze(1)], dim=1)
        d_fake_input = torch.cat([fake_2.detach(), real_1.unsqueeze(1)], dim=1)

        # Discriminator loss
        real_labels = torch.ones(real_1.size(0), 1, device=device)
        fake_labels = torch.zeros(real_1.size(0), 1, device=device)

        d_real_loss = criterion(discriminator(d_real_input), real_labels)
        d_fake_loss = criterion(discriminator(d_fake_input), fake_labels)
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        optimizer_d.step()

        # Train Generator
        # A fresh noise draw is used, and the loss rewards the generator when the
        # (just updated) discriminator labels its output as real.
        optimizer_g.zero_grad()
        gen_input = torch.stack([real_1, torch.randn_like(real_1)], dim=1)
        fake_2 = generator(gen_input)
        d_fake_input = torch.cat([fake_2, real_1.unsqueeze(1)], dim=1)
        g_loss = criterion(discriminator(d_fake_input), real_labels) #+  0.1*total_variation_loss(fake_2)
        g_loss.backward()
        optimizer_g.step()
    # Only the losses of the last batch of the epoch are recorded.
    g_losses.append(g_loss.item())
    d_losses.append(d_loss.item())
    # Every 100 epochs: plot the loss curves and preview two generated samples.
    if (epoch+1)%100==0:
        plot_losses(d_losses, g_losses)
        # NOTE(review): generator.eval() is not called, so Dropout and BatchNorm
        # run in training mode during this preview — confirm that is intended.
        # NOTE(review): the preview draws from `dataloader` (training data), not
        # `test_dataloader` — confirm.
        with torch.no_grad():
            for real_1, real_2 in dataloader:
                real_1, real_2 = real_1.to(torch.float32),real_2.to(torch.float32)
                real_1, real_2 = real_1.to(device), real_2.to(device)
                gen_input = torch.stack([real_1, torch.randn_like(real_1)], dim=1) 
                fake_2 = generator(gen_input)
                counter = plotWave(fake_2[0],real_2[0], counter)
                # Index 12 is hard-coded: this needs at least 13 samples in the batch.
                counter = plotWave(fake_2[12],real_2[12], counter) #random output
                break  # only the first batch is previewed
    print(f"Epoch {epoch+1}/{epochs}, Discrimiator Loss: {d_loss.item()}, Generator Loss: {g_loss.item()}")

Epoch 1/10000, Discrimiator Loss: 1.3817598819732666, Generator Loss: 0.7448569536209106
Epoch 2/10000, Discrimiator Loss: 0.9144365787506104, Generator Loss: 1.910752296447754
Epoch 3/10000, Discrimiator Loss: 0.659257709980011, Generator Loss: 3.731301784515381
Epoch 4/10000, Discrimiator Loss: 0.3871452212333679, Generator Loss: 4.2038726806640625
Epoch 5/10000, Discrimiator Loss: 0.2291051745414734, Generator Loss: 4.9710307121276855
Epoch 6/10000, Discrimiator Loss: 0.13461874425411224, Generator Loss: 5.471002578735352
Epoch 7/10000, Discrimiator Loss: 0.08614402264356613, Generator Loss: 6.002068519592285
Epoch 8/10000, Discrimiator Loss: 0.059051111340522766, Generator Loss: 6.324778079986572
Epoch 9/10000, Discrimiator Loss: 0.03583770617842674, Generator Loss: 6.503056526184082
Epoch 10/10000, Discrimiator Loss: 0.03300109878182411, Generator Loss: 5.915240287780762
Epoch 11/10000, Discrimiator Loss: 0.02315041795372963, Generator Loss: 6.520235538482666
Epoch 12/10000, Discr

KeyboardInterrupt: 