## CS294 Homework 1a

In [None]:
import torch
import torch.optim as optim
import torch.utils.data as data
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

quiet = False

def train(model, train_loader, optimizer, epoch, grad_clip=None):
    """Run one pass over `train_loader`, taking one optimizer step per batch.

    Args:
        model: module exposing a `loss(x)` method that returns a scalar tensor.
        train_loader: iterable yielding tensor batches.
        optimizer: optimizer over `model`'s parameters.
        epoch: current epoch index (accepted for the caller's convenience; unused).
        grad_clip: if truthy, the maximum total gradient norm passed to
            `clip_grad_norm_`; falsy values (None, 0) disable clipping.

    Returns:
        list of float: the loss of every batch, in iteration order.
    """
    model.train()
    batch_losses = []
    for batch in train_loader:
        batch_loss = model.loss(batch.contiguous())
        optimizer.zero_grad()
        batch_loss.backward()
        if grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        batch_losses.append(batch_loss.item())
    return batch_losses

def eval_loss(model, data_loader):
    """Return the per-example average of `model.loss` over `data_loader`.

    Each batch's loss is weighted by its batch size, so a short final batch
    counts proportionally. This assumes `model.loss` returns a per-batch mean
    (TODO confirm for every model used here). Runs under `torch.no_grad()`
    and leaves the model in eval mode.

    Args:
        model: module exposing a `loss(x)` method that returns a scalar tensor.
        data_loader: loader whose `.dataset` length is the normaliser.

    Returns:
        float: the weighted average loss.
    """
    model.eval()
    with torch.no_grad():
        weighted_sum = sum(
            model.loss(batch.contiguous()) * batch.shape[0] for batch in data_loader
        )
    return (weighted_sum / len(data_loader.dataset)).item()


def train_epochs(model, train_loader, test_loader, train_args):
    """Optimise `model` with Adam for several epochs, tracking train/test loss.

    Args:
        model: module exposing a `loss(x)` method that returns a scalar tensor.
        train_loader: iterable of training batches.
        test_loader: loader with a `.dataset`, evaluated via `eval_loss`.
        train_args: dict with keys 'epochs' and 'lr', and optionally
            'grad_clip' (forwarded to `train`; falsy disables clipping).

    Returns:
        (train_losses, test_losses): a list with one training loss per
        minibatch, and a list holding the test loss before training followed
        by one entry per epoch.
    """
    n_epochs, lr = train_args['epochs'], train_args['lr']
    clip = train_args.get('grad_clip', None)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    batch_losses = []
    epoch_test_losses = [eval_loss(model, test_loader)]  # loss at initialisation
    for ep in range(n_epochs):
        model.train()
        batch_losses += train(model, train_loader, optimizer, ep, clip)
        epoch_test_losses.append(eval_loss(model, test_loader))
        if not quiet:  # module-level verbosity flag
            print(f'Epoch {ep}, Test loss {epoch_test_losses[-1]:.4f}')
    return batch_losses, epoch_test_losses


class MLPToBeta(nn.Module):
    """Two-layer MLP mapping input features to the parameters of a Beta distribution.

    `forward` returns two unconstrained outputs per example. `distribution`
    passes them through softplus to obtain the Beta's positive concentration
    parameters. `loss` is the negative log-likelihood of targets under that
    distribution.
    """

    # Floor added after softplus so the concentrations stay strictly positive
    # even when softplus underflows to 0 for very negative network outputs.
    _MIN_CONCENTRATION = 1e-6

    def __init__(self, num_input_features, hidden_size=10):
        # nn.Module.__init__ must run before any submodule is assigned,
        # otherwise the Linear layers cannot be registered as children.
        super().__init__()
        self.num_features = num_input_features
        self.hidden_size = hidden_size

        self.fc1 = torch.nn.Linear(self.num_features, self.hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(self.hidden_size, 2)

    def forward(self, x):
        """Map x of shape (batch_size, num_features) to raw outputs of shape (batch_size, 2)."""
        x = self.fc1(x)
        x = self.relu(x)
        return self.fc2(x)

    def distribution(self, x):
        """Return the predicted Beta distribution, with batch_shape (batch_size,).

        Output column 0 becomes concentration1 (alpha) and column 1 becomes
        concentration0 (beta), after softplus to make both positive.
        """
        concentrations = F.softplus(self.forward(x)) + self._MIN_CONCENTRATION
        return torch.distributions.beta.Beta(concentrations[..., 0], concentrations[..., 1])

    def loss(self, x, y=None):
        """Mean negative log-likelihood of targets `y` under the Beta predicted from `x`.

        Args:
            x: (batch_size, num_features) input features.
            y: (batch_size,) targets in the unit interval. Values exactly at 0
               or 1 may give a non-finite loss. If omitted, the predicted
               Beta distribution is returned instead of a loss.

        Returns:
            A scalar tensor when `y` is given, otherwise a Beta distribution.
        """
        # NOTE(review): train() in this notebook calls model.loss(x) with one
        # argument, so training this model needs a loader that also yields
        # targets — TODO wire that up.
        dist = self.distribution(x)
        if y is None:
            return dist
        return -dist.log_prob(y).mean()


class Histogram(nn.Module):
    """Learnable categorical distribution over `d` discrete values, parameterised by logits."""

    def __init__(self, d):
        super().__init__()
        self.d = d
        # All-zero logits: the model starts as a uniform distribution.
        self.logits = nn.Parameter(torch.zeros(d))

    def loss(self, x):
        """Mean negative log-likelihood (cross-entropy) of integer samples `x`.

        Args:
            x: (batch_size,) tensor of values in {0, ..., d-1}; cast to long.

        Returns:
            Scalar tensor.
        """
        # expand() broadcasts the shared logits to (batch_size, d) as a view,
        # avoiding the per-batch copy that repeat() made; values and gradients
        # are identical.
        logits = self.logits.unsqueeze(0).expand(x.shape[0], -1)
        return F.cross_entropy(logits, x.long())

    def get_distribution(self):
        """Return the model's probabilities as a (d,) numpy array."""
        return F.softmax(self.logits, dim=0).detach().cpu().numpy()

    
def learn_histogram(train_data, test_data, d, batch_size=128, epochs=20, lr=1e-1):
    """Fit a `Histogram` model to discrete data by maximum likelihood.

    Args:
        train_data: An (n_train,) numpy array or tensor of integers in {0, ..., d-1}.
        test_data: An (n_test,) numpy array or tensor of integers in {0, ..., d-1}.
        d: The number of possible discrete values for random variable x.
        batch_size: Minibatch size for both loaders (default 128).
        epochs: Number of passes over the training data (default 20).
        lr: Adam learning rate (default 0.1).

    Returns:
        - a (# of training iterations,) numpy array of train_losses evaluated every minibatch
        - a (# of epochs + 1,) numpy array of test_losses evaluated once at initialization and after each epoch
        - a numpy array of size (d,) of model probabilities

    Raises:
        ValueError: if either dataset contains a value outside [0, d-1].
    """
    # Validate up front: out-of-range labels would otherwise only fail inside
    # F.cross_entropy partway through training.
    for split_name, split in (('train_data', train_data), ('test_data', test_data)):
        values = torch.as_tensor(split)
        if values.numel() and (values.min() < 0 or values.max() >= d):
            raise ValueError(
                f'{split_name} must contain values in [0, {d - 1}], '
                f'got range [{values.min().item()}, {values.max().item()}]'
            )

    model = Histogram(d)
    train_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = data.DataLoader(test_data, batch_size=batch_size)
    train_losses, test_losses = train_epochs(
        model, train_loader, test_loader, dict(epochs=epochs, lr=lr)
    )
    distribution = model.get_distribution()

    # train_epochs returns Python lists; convert them to match the documented return types.
    return np.array(train_losses), np.array(test_losses), distribution

In [None]:
# Toy data: 1000 train and 1000 test draws from N(mean, stddev^2), each
# filtered to the closed interval [lower_limit, upper_limit]. The train split
# is sampled first, then the test split.
mean, stddev = 3000, 1000
lower_limit, upper_limit = 0, 10000
train_data, test_data = (
    draw[(draw >= lower_limit) & (draw <= upper_limit)]
    for draw in (
        torch.distributions.normal.Normal(mean, stddev).sample((1000,))
        for _ in range(2)
    )
)
train_data[:100]

In [None]:
# Histogram of draws from a distribution `dist` (not defined in the visible cells; kept for reference)
# samples = dist.sample((10000,))
# plt.hist((samples / samples.sum()).numpy(), bins=20)

In [None]:
# Empirical histogram of the raw training samples, for comparison with the
# learned model. plt.hist returns (counts, bin_edges, patches), so bind those
# names rather than the losses/distribution names used for learn_histogram's outputs.
# NOTE(review): `plt` is not imported in the visible cells — this needs
# `import matplotlib.pyplot as plt` to have run first.
hist_counts, hist_bin_edges, hist_patches = plt.hist(train_data.numpy(), bins=20)

In [None]:
# learn_histogram expects integer labels in {0, ..., d-1}, but train_data and
# test_data hold continuous samples in [lower_limit, upper_limit]. Map each
# sample to one of n_bins equal-width bins over that range first; the raw
# floats would become out-of-range class indices after .long().
n_bins = 20
bin_width = (upper_limit - lower_limit) / n_bins
# clamp keeps a sample exactly at upper_limit in the last bin.
train_bins = ((train_data - lower_limit) / bin_width).long().clamp(0, n_bins - 1)
test_bins = ((test_data - lower_limit) / bin_width).long().clamp(0, n_bins - 1)
train_losses, test_losses, distribution = learn_histogram(train_bins, test_bins, n_bins)
distribution

In [None]:
# One bar per discrete value; the count follows the learned distribution's
# length instead of a hard-coded 20, so it stays in sync with d.
plt.bar(np.arange(len(distribution)), distribution)