In [66]:
import torch
from torch import nn, Tensor
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F

import glob
import time
import json

from typing import Callable
from collections import Counter

from loguru import logger
from tqdm import tqdm
logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)

9

In [67]:
class Review:
    def __init__(self, quote: str, score: float) -> None:
        self.quote = quote.strip()
        self.score = score
        self.isacii = self.quote.isascii()
    def __repr__(self) -> str:
        return f'''Rating: {self.score}
{self.review}
'''

class TomatoDataset(Dataset):
    def __init__(self, root_dir: str):
        self.root_dir = root_dir
        self.files = glob.glob(root_dir + '/*.json')
        self.data: list[Review] = []
        self.counter = Counter()
        for fname in self.files:
            with open(fname) as f:
                try:
                    data: dict[str, dict] = json.load(f)
                except json.JSONDecodeError:
                    logger.warning(f'Failed to load {fname}')
                    continue
            for _, item in data.items():
                assert item['rating'] == item['score']
                rating = item['rating']
                quote = item['quote']
                
                if self.counter[rating] > 200: # TEMP
                    continue
                self.counter[rating] += 1
                
                review = Review(quote, rating)
                if review.isacii:
                    self.data.append(review)
        logger.info(f"{self.counter}")
        logger.info(f'Loaded {len(self.data)} reviews')

    def __len__(self) -> int:
        return len(self.data)
    
    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        review = self.data[idx]
        quote, score = review.quote, review.score
        quote_ = torch.frombuffer(quote.encode('ascii', 'ignore'), dtype=torch.uint8)
        score_ = torch.tensor(score / 5)
        return quote_, score_

In [68]:
dataset_path = "./theater"

dataset = TomatoDataset(dataset_path)

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

train_data, test_data = torch.utils.data.random_split(dataset, [train_size, test_size])

train_loader, test_loader = DataLoader(train_data, batch_size=1), DataLoader(test_data, batch_size=1)

logger.info(f'Train size: {len(train_loader)}')
logger.info(f'Test size: {len(test_loader)}')

# dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
# for i, data in enumerate(dataloader):
#     print(data)
#     if i > 10:
#         break

logger.debug(next(iter(train_loader))[0])

[32m2023-07-20 00:42:31.273[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m36[0m - [1mCounter({5: 201, 3: 201, 4.5: 201, 3.5: 201, 4: 201, 0.5: 201, 2: 201, 1: 201, 2.5: 165, 1.5: 111})[0m
[32m2023-07-20 00:42:31.273[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m37[0m - [1mLoaded 1661 reviews[0m
[32m2023-07-20 00:42:31.274[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m12[0m - [1mTrain size: 1328[0m
[32m2023-07-20 00:42:31.274[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [1mTest size: 333[0m
[32m2023-07-20 00:42:31.275[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m21[0m - [34m[1mtensor([[ 84, 121, 112, 105,  99,  97, 108,  32,  87, 101, 115,  32,  65, 110,
         100, 101, 114, 115, 111, 110,  32, 119, 105, 116, 104,  32, 114, 101,
         112, 101,  97, 116, 101, 100,  32,  97, 116, 116, 101, 109, 112, 116,
         115,  32, 116, 111,  32,  98, 101,  

In [69]:
class Net(nn.Module):
    def __init__(self, vocab_size=256, num_hidden=128, num_layers=2):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_hidden = num_hidden
        self.num_layers = num_layers
        self.lstm = nn.LSTM(vocab_size, num_hidden, num_layers)
        self.linear = nn.Linear(num_hidden, 1)
    
    def forward(self, inputs, state):
        X = F.one_hot(inputs.T.long(), self.vocab_size)
        X = X.float()
        Y, state = self.lstm(X, state)
        output = self.linear(Y)
        output = torch.sigmoid(output)
        return output, state
    
    def begin_state(self, device, batch_size=1):
        return (
            torch.zeros(self.num_layers, batch_size, self.num_hidden, device=device),
            torch.zeros(self.num_layers, batch_size, self.num_hidden, device=device)
        )

In [70]:
vocab_size, num_hidden, num_layers = 256, 128, 2
device = 'cuda:0'

net = Net(vocab_size, num_hidden, num_layers)

for m in net.modules():
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
    elif isinstance(m, nn.LSTM):
        for param in m.parameters():
            if len(param.shape) >= 2:
                nn.init.xavier_uniform_(param)


net.to(device)

# for x in net.parameters():
#     print(x.shape)

Net(
  (lstm): LSTM(256, 128, num_layers=2)
  (linear): Linear(in_features=128, out_features=1, bias=True)
)

In [71]:
def grad_clipping(net, theta):
    params = [p for p in net.parameters() if p.requires_grad]
    norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))
    # logger.debug(f'grad norm: {norm}')
    if norm > theta:
        for param in params:
            param.grad[:] *= theta / norm

In [72]:
def evaluate(net, val_loader, loss):
    net.eval()
    losses = []
    for X, Y in val_loader:
        state = net.begin_state(device, batch_size=X.shape[0])
        X, Y = X.to(device), Y.to(device)
        y_hat, state = net(X, state)
        final_y_hat = y_hat[-1].reshape(-1)
        l = loss(Y, final_y_hat)
        losses.append(l.item())
    return sum(losses) / len(losses)

def train(net, train_loader, val_loader, lr, num_epochs):
    loss = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    
    for epoch in tqdm(range(num_epochs)):
        net.train()
        for X, Y in train_loader:
            state = net.begin_state(device, batch_size=X.shape[0])
            
            X, Y = X.to(device), Y.to(device)
            y_hat, state = net(X, state)
            # final_y_hat = y_hat[-1].reshape(-1)
            # import pdb; pdb.set_trace()
            Y = Y.repeat(y_hat.shape[0], 1).reshape(-1, 1, 1)
            l = loss(Y, y_hat)
            # logger.debug(Y)
            # logger.debug(f"Y {Y.norm()} y_hat {y_hat.norm()} loss {l.item()}")
            
            optimizer.zero_grad()
            l.backward()
            # logger.info([x.grad for x in optimizer.param_groups[0]['params']])
            # input("pause")
            grad_clipping(net, 1.)
            optimizer.step()
        logger.info(f'Epoch {epoch} loss: {l.item():.6f}')
        val_loss = evaluate(net, val_loader, loss)
        logger.info(f'Epoch {epoch} val loss: {val_loss:.6f}')
        if epoch % 10 == 0:
            torch.save(net.state_dict(), f'./ckpts/model_{epoch}.pt')


In [73]:
train(net, train_loader, test_loader, 1.0, 100)

  0%|          | 0/100 [00:05<?, ?it/s]

[32m2023-07-20 00:42:38.109[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m37[0m - [1mEpoch 0 loss: 1.000000[0m


  1%|          | 1/100 [00:06<10:19,  6.26s/it]

[32m2023-07-20 00:42:38.659[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m39[0m - [1mEpoch 0 val loss: 0.411081[0m


  1%|          | 1/100 [00:11<10:19,  6.26s/it]

[32m2023-07-20 00:42:43.964[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m37[0m - [1mEpoch 1 loss: 1.000000[0m


  2%|▏         | 2/100 [00:11<09:43,  5.95s/it]

[32m2023-07-20 00:42:44.404[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m39[0m - [1mEpoch 1 val loss: 0.411081[0m


  2%|▏         | 2/100 [00:17<09:43,  5.95s/it]

[32m2023-07-20 00:42:50.163[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m37[0m - [1mEpoch 2 loss: 1.000000[0m


  3%|▎         | 3/100 [00:18<09:50,  6.08s/it]

[32m2023-07-20 00:42:50.646[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m39[0m - [1mEpoch 2 val loss: 0.411081[0m


  3%|▎         | 3/100 [00:23<09:50,  6.08s/it]

[32m2023-07-20 00:42:55.654[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m37[0m - [1mEpoch 3 loss: 1.000000[0m


  4%|▍         | 4/100 [00:23<09:23,  5.87s/it]

[32m2023-07-20 00:42:56.199[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m39[0m - [1mEpoch 3 val loss: 0.411081[0m


  4%|▍         | 4/100 [00:29<11:38,  7.27s/it]


KeyboardInterrupt: 