In [3]:
import torch as pt
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.optim import SGD, Adam
from tqdm import tqdm
# Debug switch: when truthy, generate_data() prints how many pairs it produced
# together with one example sample.
DEBUG = 1


In [4]:
def generate_data(data_size = 10000, dim = 1, bound = 100, rand_dist = "uniform"):
    """Generate random samples for learning element-wise multiplication.

    Each sample is a tuple ``(x, z)``:

    * ``x`` has shape ``(1, 2 * dim)``: the first ``dim`` entries are the
      first factor, the remaining ``dim`` entries are the second factor.
    * ``z`` has shape ``(1, dim)`` and holds the element-wise product of the
      two factors.

    Args:
        data_size: Number of samples to generate. Zero or a negative value
            yields an empty list.
        dim: Number of elements in each factor.
        bound: Scale applied to every raw random draw.
        rand_dist: ``"uniform"`` to draw with ``pt.rand`` or ``"normal"`` to
            draw with ``pt.randn``.

    Returns:
        A list of ``(x, z)`` tensor tuples.

    Raises:
        NotImplementedError: If ``rand_dist`` is neither ``"uniform"`` nor
            ``"normal"``.
    """
    if rand_dist == "uniform":
        rand_func = pt.rand
    elif rand_dist == "normal":
        rand_func = pt.randn
    else:
        raise NotImplementedError("Must choose from uniform or normal distribution.")

    data_list = []
    for _ in range(data_size):
        x = rand_func(dim) * bound
        y = rand_func(dim) * bound
        z = (x * y).unsqueeze(0)
        # Pack both factors into one row so a model receives a single input tensor.
        xy = pt.concat((x, y), dim = 0).unsqueeze(0)
        data_list.append((xy, z))

    if DEBUG:
        print(f"generated {len(data_list)} pairs")
        # Only show an example when at least one sample exists; with an empty
        # result there is nothing to display.
        if data_list:
            example_x, example_z = data_list[-1]
            print("Example data:")
            print(f"    x = {example_x[0, : dim]}")
            print(f"    y = {example_x[0, dim: ]}")
            print(f"    z =  x * y = {example_z}")
    return data_list

In [7]:
class MultDataset(Dataset):
    """Thin ``Dataset`` wrapper around an indexable collection of ``(x, z)`` samples."""

    def __init__(self, data):
        """Store ``data``; it is indexed and measured as-is, without copying."""
        super().__init__()
        self.data = data

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the sample stored at ``index``."""
        return self.data[index]

    @classmethod
    def collate_fn(cls, batch):
        """Merge a batch of ``(x, z)`` samples into two batched tensors.

        The inputs and the targets are each concatenated along dimension 0.
        """
        inputs, targets = zip(*batch)
        return pt.concat(inputs, dim=0), pt.concat(targets, dim=0)

In [None]:
class LP(nn.Module):
    """Perceptron with one hidden layer: Linear -> ReLU -> Linear."""

    def __init__(self, input_size, hidden_size, output_size):
        """Build the two linear layers and remember the layer sizes.

        Args:
            input_size: Number of input features.
            hidden_size: Width of the hidden layer.
            output_size: Number of output features.
        """
        super().__init__()
        self.input_size, self.hidden_size, self.output_size = (
            input_size,
            hidden_size,
            output_size,
        )

        # Sub-module registration order fixes the order of model.parameters()
        # and of the state_dict keys.
        self.input_layer = nn.Linear(input_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, output_size)
        self.activation_func = nn.ReLU()

    def forward(self, x):
        """Project ``x`` into the hidden space, apply ReLU, project to the output."""
        hidden = self.activation_func(self.input_layer(x))
        return self.output_layer(hidden)
        
        
class MLP(nn.Module):
    """Multi-layer perceptron with ReLU after every layer except the last.

    The network is: input Linear -> ``hidden_num`` square Linear layers ->
    output Linear.
    """

    def __init__(self, input_size, hidden_size, hidden_num, output_size):
        """Build the layers and remember the configuration.

        Args:
            input_size: Number of input features.
            hidden_size: Width shared by all hidden layers.
            hidden_num: Number of ``hidden_size x hidden_size`` layers placed
                between the input and output layers.
            output_size: Number of output features.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.hidden_num = hidden_num

        # Sub-module registration order fixes the order of model.parameters()
        # and of the state_dict keys.
        self.input_layer = nn.Linear(input_size, hidden_size)
        hidden_stack = nn.ModuleList()
        for _ in range(hidden_num):
            hidden_stack.append(nn.Linear(hidden_size, hidden_size))
        self.hidden_layers = hidden_stack
        self.output_layer = nn.Linear(hidden_size, output_size)
        self.activation_func = nn.ReLU()

    def forward(self, x):
        """Run ``x`` through every activated layer, then the final linear layer."""
        h = x
        # The input layer and each hidden layer are all followed by the activation.
        for layer in (self.input_layer, *self.hidden_layers):
            h = self.activation_func(layer(h))
        return self.output_layer(h)
    

In [None]:
# Experiment configuration.
dim = 1
epochs = 1000
# Use the GPU when one is available, otherwise fall back to the CPU so the
# cell still runs on machines without CUDA.
device = pt.device("cuda" if pt.cuda.is_available() else "cpu")
# Training factors are generated with bound=1, test factors with bound=10.
train_dataset = MultDataset(generate_data(dim=dim, bound=1))
test_dataset = MultDataset(generate_data(1000, dim=dim, bound=10))

train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, collate_fn=MultDataset.collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False, collate_fn=MultDataset.collate_fn)
#model = LP(2*dim, 30*dim, dim).to(device)
model = MLP(2*dim, 30*dim, 3, dim).to(device)
loss_fn = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=5e-4)# SGD(model.parameters(), lr = 1e-2)
print(len(test_dataset))

normalized_input = True

for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    for x, z in tqdm(train_dataloader, desc=f'epoch {epoch}'):
        x = x.to(device)
        z = z.to(device).squeeze()
        optimizer.zero_grad()
        pred = model(x).squeeze()
        loss = loss_fn(pred, z)
        loss.backward()
        optimizer.step()

    # ---- evaluation pass ----
    acc = 0        # running count of predictions whose squared error is below 1
    res = 0        # running sum of squared errors
    num_eval = 0   # number of evaluated elements, used to normalise both metrics
    model.eval()
    # No gradients are needed for evaluation; disabling them avoids building
    # autograd graphs and keeps the accumulated sums detached.
    with pt.no_grad():
        for x, z in tqdm(test_dataloader):
            x = x.to(device)
            z = z.to(device).squeeze()
            pred = model(x).squeeze()
            diff = (z - pred) ** 2
            res += diff.sum()
            acc += (diff < 1).sum()
            num_eval += diff.numel()
    # Report every epoch for short runs, otherwise only every 20th epoch.
    if epochs < 200 or (epoch % 20 == 0):
        denom = max(num_eval, 1)  # guard against an empty test loader
        print(f'acc {acc/denom}')
        print(f'res {res/denom}')
        