In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim

In [None]:

def fsin(f, n_points=320):
    """Generate one randomly shifted, scaled and offset sine wave.

    The wave is sampled on ``n_points`` evenly spaced positions in [-1, 1] and
    computed as ``sin((x + phase) * f) * amplitude + offset``, where phase,
    amplitude and offset are each drawn uniformly from [0, 1) using NumPy's
    global random state.

    Args:
        f: Frequency multiplier applied to the phase-shifted sample grid.
        n_points: Number of samples to generate. Defaults to 320, the length
            that was previously hard-coded.

    Returns:
        A 1-D ``numpy.ndarray`` of length ``n_points``.
    """
    # Exactly three random values are needed; the earlier version drew five
    # and silently discarded two of them.
    phase, amplitude, offset = np.random.random(size=3)
    x = np.linspace(-1, 1, n_points)
    return np.sin((x + phase) * f) * amplitude + offset
    

In [None]:
class sorq(Dataset):
    """Synthetic dataset of random sine waves whose frequency encodes the class.

    Item ``index`` belongs to class ``index % classes`` and is a wave produced
    by ``fsin`` with frequency ``class * 4 + 2``. Waves are generated on every
    access, so reading the same index twice yields two different random waves
    of the same class.

    Args:
        length: Number of items the dataset reports via ``len()``.
        device: Device on which the returned tensors are created.
    """

    def __init__(
            self,
            length=6000,
            device="cuda:0"
    ):
        self.classes = 4
        self.length = length
        self.device = device

    def __len__(self):
        """Return the configured number of items."""
        return self.length

    def __getitem__(self, index):
        """Return ``(wave, label)`` for ``index``.

        Returns:
            A tuple of a float tensor with a trailing singleton feature axis
            (the wave, reshaped with ``view(-1, 1)``) and a float tensor of
            shape ``(1,)`` holding the class index.

        Raises:
            IndexError: If ``index`` is outside ``[-length, length)``.
        """
        # Without a bounds check the legacy sequence protocol
        # (`for item in dataset`) never receives an IndexError and would
        # iterate forever.
        if not -self.length <= index < self.length:
            raise IndexError(
                f"index {index} out of range for dataset of length {self.length}"
            )
        if index < 0:
            # Normalise negative indices the way Python sequences do.
            index += self.length

        label = index % self.classes
        wave = fsin(label * 4 + 2)
        # Create the tensors directly with the final dtype and device instead
        # of building a float64 tensor and converting it in two extra steps.
        return (
            torch.as_tensor(wave, dtype=torch.float, device=self.device).view(-1, 1),
            torch.tensor([label], dtype=torch.float, device=self.device),
        )

    def __str__(self) -> str:
        """Return the dataset length as a string."""
        return f"{self.length}"

# Ten thousand synthetic samples; tensors are created on the dataset's default device.
dataset = sorq(length=10_000)

In [None]:
# Overlay the first 20 generated waves to compare the classes visually.
for sample_idx in range(20):
    wave = dataset[sample_idx][0]
    plt.plot(wave.detach().cpu())
plt.show()

In [None]:
class LSTMDiscriminatorRF(nn.Module):
    """LSTM-based sequence discriminator / classifier.

    The input sequence is run through a stacked LSTM. The LSTM outputs of all
    time steps are flattened into one vector per sequence and mapped by a
    single linear layer to ``out_dim`` scores.

    Args:
        in_dim: Number of features per time step of the input.
        out_dim: Number of scores produced per sequence.
        seq_len: Sequence length the linear head is sized for; inputs must
            have exactly this many time steps.
        n_layers: Number of stacked LSTM layers.
        hidden_dim: Hidden-state size of each LSTM layer.

    Inputs: tensor of shape (batch_size, seq_len, in_dim)
    Output: tensor of shape (batch_size, out_dim)
    """

    def __init__(self, in_dim, out_dim = 1,seq_len = 320, n_layers=1, hidden_dim=256):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self.lstm = nn.LSTM(in_dim, hidden_dim, n_layers, batch_first=True)
        # Kept inside nn.Sequential so existing state_dict keys stay valid.
        self.linear_recurrent = nn.Sequential(nn.Linear(hidden_dim*seq_len, out_dim))

    def forward(self, input):
        """Score a batch of sequences.

        Args:
            input: Tensor of shape (batch_size, seq_len, in_dim).

        Returns:
            Tensor of shape (batch_size, out_dim).
        """
        # nn.LSTM initialises h_0 / c_0 to zeros matching the input when no
        # state is passed, so the module works on any device instead of
        # requiring a hard-coded "cuda:0".
        rf, _ = self.lstm(input)
        # Flatten (seq_len, hidden_dim) into one feature vector per sequence.
        return self.linear_recurrent(rf.flatten(start_dim=1))

In [None]:
# Training configuration.
batch_size = 16
input_dimenstions = 1  # features per time step (spelling kept; other cells may reference this name)
layers = 6
hidden_size = 1024
# Use the batch_size constant rather than repeating the literal 16.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Model and loss live on the device the dataset emits its tensors on, and the
# number of output scores follows the dataset's class count.
disc1 = LSTMDiscriminatorRF(
    input_dimenstions,
    n_layers=layers,
    hidden_dim=hidden_size,
    out_dim=dataset.classes,
).to(dataset.device)
optimizer1 = optim.RAdam(disc1.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss().to(dataset.device)
# NOTE(review): with gamma=0.01 the learning rate drops to 1e-6, 1e-8 and
# 1e-10 after epochs 1, 2 and 5 -- confirm this very steep decay is intended.
scheduler1 = optim.lr_scheduler.MultiStepLR(optimizer1, milestones=[1,2,5], gamma=0.01)

In [None]:
# Train disc1 to classify each wave into its frequency class.
print("epoch |  bnum  | errD1  | lr")
for epoch in range(100):
    for i, (data,label) in enumerate(dataloader, 0):
        # optimizer1 was built from disc1.parameters(), so a single
        # zero_grad() call clears every gradient of the model.
        optimizer1.zero_grad()
        output = disc1(data)
        # CrossEntropyLoss takes class indices as a 1-D long tensor; place it
        # on the logits' device instead of a hard-coded "cuda:0".
        target = label.view(-1).to(device=output.device, dtype=torch.long)
        errD1 = criterion(output, target)
        errD1.backward()
        optimizer1.step()
        # "\r" rewrites the status line in place; every 100th batch a newline
        # below keeps that line in the output as a checkpoint.
        print(f"{epoch:5} | {i:6} | {errD1.item():6.3} | {scheduler1.get_last_lr()}",end = "\r",flush=True)
        if i%100 == 0:
            print()
    # The learning-rate schedule advances once per epoch.
    scheduler1.step()
    print("\nepoch finished: ",epoch)

In [None]:
# Run the model on one shuffled batch of 100 samples and show the shapes.
c,label = next(iter(DataLoader(dataset=dataset,batch_size=100,shuffle=True)))
# Inference only: disable autograd so no computation graph is kept for the batch.
with torch.no_grad():
    output = disc1(c)
output.shape,label.shape