In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim


In [2]:
from CoRe_Dataloader import dataset

In [3]:
class LSTMDiscriminatorHOLS(nn.Module):
    """LSTM sequence classifier that emits one score vector per input sequence.

    The sequence is run through a (stacked, unidirectional) LSTM; the final
    hidden state of the top layer is projected to ``n_classes`` scores. By
    default a sigmoid is applied to those scores (the original behaviour), so
    every entry lies in (0, 1).

    Args:
        in_dim: number of features per time step (last dimension of the input).
        n_classes: number of output scores per sequence.
        n_layers: number of stacked LSTM layers.
        hidden_dim: hidden-state size of each LSTM layer.
        output_sigmoid: if True (default), apply a sigmoid to the output scores.
            Pass False to get raw logits instead, which is what
            ``nn.CrossEntropyLoss`` expects (it applies log-softmax itself;
            feeding it sigmoid outputs squashes the logits into (0, 1) and
            caps how low the loss can go).

    Inputs: tensor of shape (batch_size, seq_len, in_dim).
    Output: tensor of shape (batch_size, n_classes).
    """

    def __init__(self, in_dim, n_classes=19, n_layers=1, hidden_dim=256, output_sigmoid=True):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim

        self.lstm = nn.LSTM(in_dim, hidden_dim, n_layers, batch_first=True)
        # The Linear layer stays at index 0 of the Sequential in both modes, so
        # state_dict keys ("linear_hidden.0.weight"/".bias") remain compatible
        # with checkpoints saved by the sigmoid version.
        head = [nn.Linear(hidden_dim, n_classes)]
        if output_sigmoid:
            head.append(nn.Sigmoid())
        self.linear_hidden = nn.Sequential(*head)

    def forward(self, input):
        batch_size = input.size(0)
        # Build the zero initial states on the input's own device/dtype rather
        # than a hard-coded "cuda:0", so the module runs on CPU or any GPU.
        h_0 = input.new_zeros(self.n_layers, batch_size, self.hidden_dim)
        c_0 = input.new_zeros(self.n_layers, batch_size, self.hidden_dim)

        _, (h_out, _) = self.lstm(input, (h_0, c_0))
        # h_out is (n_layers, batch, hidden_dim) even with batch_first=True;
        # classify from the top layer's final hidden state.
        return self.linear_hidden(h_out[-1])

In [4]:
# --- Training configuration -------------------------------------------------
SEED = 0
DEVICE = torch.device("cuda:0")  # single device for model and loss
batch_size = 16
input_dimenstions = 1  # (sic) features per time step; name kept for other cells
layers = 1
hidden_size = 2048
learning_rate = 1e-4

# Seed torch so weight init and DataLoader shuffling are reproducible.
torch.manual_seed(SEED)

# Use the batch_size constant instead of repeating the literal 16.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

disc1 = LSTMDiscriminatorHOLS(input_dimenstions, n_layers=layers, hidden_dim=hidden_size).to(DEVICE)
optimizer1 = optim.Adam(disc1.parameters(), lr=learning_rate)
# NOTE(review): nn.CrossEntropyLoss expects raw logits, but the model's head as
# written ends in a Sigmoid, which squashes scores into (0, 1) and limits how
# low this loss can go — consider training on logits instead.
criterion = nn.CrossEntropyLoss().to(DEVICE)


In [5]:
n_epochs = 30
log_every = 100  # keep a permanent progress line every `log_every` batches

# Move every batch to wherever the model's parameters live, instead of
# relying on the dataset happening to produce tensors on the same device.
model_device = next(disc1.parameters()).device

disc1.train()
print("epoch |  bnum  | errD1  |")
for epoch in range(n_epochs):
    for i, (data, label) in enumerate(dataloader):
        # Add a trailing feature axis for the LSTM's in_dim=1 — assumes each
        # sample is a 1-D sequence, i.e. data is (batch, seq_len); TODO confirm.
        inputs = data.unsqueeze(-1).to(device=model_device, dtype=torch.float)
        # The first label column is used as the class index — presumably the
        # target class; verify against CoRe_Dataloader.
        clabel = label[:, 0].to(device=model_device, dtype=torch.long)

        optimizer1.zero_grad()
        coutput = disc1(inputs)
        errD1 = criterion(coutput, clabel)
        errD1.backward()
        optimizer1.step()

        # "\r" overwrites the status line; every `log_every` batches a newline
        # is emitted so that line stays visible.
        print(f"{epoch:5} | {i:6} | {errD1.item():6.3}", end="\r", flush=True)
        if i % log_every == 0:
            print()
    print("\nepoch finished: ", epoch)

epoch |  bnum  | errD1  |
    0 |      0 |   2.95
    0 |    100 |   2.32
    0 |    140 |   2.45
epoch finished:  0
    1 |      0 |    2.5
    1 |     96 |    2.5

KeyboardInterrupt: 