In [1]:
def load_data(path='data/sst2/data.csv'):
    """Read the tokenised dataset from a CSV file into a DataFrame.

    Args:
        path: Location of the CSV file. Defaults to the path this notebook
            has always used, so ``load_data()`` behaves as before.

    Returns:
        pandas.DataFrame: the parsed file, exactly as ``pd.read_csv``
        returns it.
    """
    # Kept as a function-local import so the cell stays self-contained.
    import pandas as pd

    return pd.read_csv(path)

# Read the CSV into a DataFrame and preview its first rows.
data = load_data()
data.head()

Unnamed: 0,x,y
0,"101,5342,2047,3595,8496,2013,1996,18643,3197,1...",0
1,"101,3397,2053,15966,1010,2069,4450,2098,18201,...",0
2,"101,2008,7459,2049,3494,1998,10639,2015,2242,2...",1
3,"101,3464,12580,8510,2000,3961,1996,2168,2802,1...",0
4,"101,2006,1996,5409,7195,1011,1997,1011,1996,10...",0


In [2]:
import torch
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over a two-column table of (token string, label).

    Each row is unpacked into two values: a comma-separated string of
    integer token ids and a label convertible with ``int``.
    """

    def __init__(self, data):
        # Keep a reference to the table; rows are parsed lazily on access.
        self.data = data

    def __len__(self):
        """Number of rows in the wrapped table."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(token_ids, label)`` for the row at position ``index``.

        ``token_ids`` is a 1-D LongTensor parsed from the comma-separated
        field; ``label`` is a plain Python int.
        """
        token_field, label = self.data.iloc[index]
        token_ids = torch.LongTensor(list(map(int, token_field.split(","))))
        return token_ids, int(label)
    
# Wrap the DataFrame, then show the sample count and the first (token tensor, label) pair.
dataset = Dataset(data)
len(dataset), dataset[0]

(65000,
 (tensor([  101,  5342,  2047,  3595,  8496,  2013,  1996, 18643,  3197,   102,
              0,     0,     0,     0,     0]),
  0))

In [3]:
# Shuffled mini-batches of 32 samples; drop_last=True discards a final incomplete batch.
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)
# Show the number of batches per epoch together with one sample batch.
len(loader), next(iter(loader))

(2032,
 [tensor([[  101,  1037,  2132,  1011,  6769,  1011,  1011, 19897,  2517,   102,
               0,     0,     0,     0,     0],
          [  101, 21864, 15952,  7726,  3689,  3595, 10428,   102,     0,     0,
               0,     0,     0,     0,     0],
          [  101,  1005,  1055,  2525,  1037,  8257,  1999,  1996,  2142,  2163,
            1012,   102,     0,     0,     0],
          [  101,  2000,  2031,  6404,  2673,  2002,  2412,  2354,  2055, 11717,
           23873,   102,     0,     0,     0],
          [  101,  2009, 18276,  2003,  2008,  2009,  1005,  1055,  2036,  2028,
            1997,  1996,  6047,  4355,   102],
          [  101,  1037,  2092,  1011, 10849,  5891,  2004,  2028,   102,     0,
               0,     0,     0,     0,     0],
          [  101,  2008, 13695,  2630, 10188,  2046,  2028,  1997,  1996,  2621,
            1005,  1055,  2087, 22512,   102],
          [  101,  1037,  2034,  1011,  2465,  1010, 12246,  5994,  1038,  3185,
            2008

In [5]:
class Model(torch.nn.Module):
    """Embedding -> single-layer LSTM -> linear classifier.

    Args:
        vocab_size: Number of rows in the embedding table (default 30522).
        embed_dim: Size of each token embedding (default 128).
        hidden_dim: LSTM hidden size (default 128).
        num_classes: Number of output logits (default 2).

    All arguments default to the values this notebook originally
    hard-coded, so ``Model()`` builds the same architecture as before.
    """

    def __init__(self, vocab_size=30522, embed_dim=128, hidden_dim=128, num_classes=2):
        super().__init__()

        self.embedding = torch.nn.Embedding(vocab_size, embed_dim)
        self.rnn = torch.nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = torch.nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of token-id sequences.

        Args:
            x: Integer tensor of shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalised logits.
        """
        x = self.embedding(x)
        x, _ = self.rnn(x)
        # Use the LSTM output at the final time step as the sequence summary.
        # NOTE(review): the batches printed in this notebook appear to be
        # right-padded with 0, so the final step is usually a padding token;
        # selecting the last non-pad step may work better -- confirm before
        # changing, since it alters the behaviour of trained checkpoints.
        x = x[:, -1]
        x = self.fc(x)

        return x
    
# Instantiate the model and check the output shape on a dummy batch of 8 sequences of length 15.
model = Model()
model(torch.ones(8, 15).long()).shape

torch.Size([8, 2])

In [7]:
class train():
    """Callable that trains a model for one pass over a loader and saves it.

    Args:
        model: The ``torch.nn.Module`` to optimise.
        loader: Iterable yielding ``(x, y)`` batches, where ``model(x)``
            produces logits accepted by ``CrossEntropyLoss`` with targets ``y``.
        save_path: Where the whole model object is written after the pass.
            Defaults to the path the notebook has always used.
    """

    def __init__(self, model, loader, save_path='model/7.model'):
        self.model = model
        self.loader = loader
        self.save_path = save_path

        self.criterion = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)

    def __call__(self, epoch):
        """Run one pass over ``self.loader`` and save the model.

        Args:
            epoch: Label used only in the progress log lines.
        """
        import os

        # Create the output directory up front so a missing folder fails
        # (or is fixed) before the training pass rather than after it.
        directory = os.path.dirname(self.save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        self.model.train()

        for batch, (x, y) in enumerate(self.loader):
            self.optimizer.zero_grad()

            y_hat = self.model(x)
            loss = self.criterion(y_hat, y)

            loss.backward()
            self.optimizer.step()

            # Log the loss of every 100th batch.
            if batch % 100 == 0:
                print(f"Epoch: {epoch}, Batch: {batch}, Loss: {loss.item()}")

        # Save the model this trainer actually optimised. The original code
        # saved the notebook-global `model`, which only worked when that
        # global happened to be the same object.
        # The full module is pickled (not just a state_dict) because the
        # notebook's evaluation cell reloads it with a bare torch.load.
        torch.save(self.model, self.save_path)

# Run a single training pass over the loader; the argument 0 is the epoch label shown in the log.
train(model, loader)(0)

Epoch: 0, Batch: 0, Loss: 0.69573974609375
Epoch: 0, Batch: 100, Loss: 0.6915282607078552
Epoch: 0, Batch: 200, Loss: 0.6666579246520996
Epoch: 0, Batch: 300, Loss: 0.6574747562408447
Epoch: 0, Batch: 400, Loss: 0.5668902397155762
Epoch: 0, Batch: 500, Loss: 0.5231608152389526
Epoch: 0, Batch: 600, Loss: 0.49259087443351746
Epoch: 0, Batch: 700, Loss: 0.4962180256843567
Epoch: 0, Batch: 800, Loss: 0.46396827697753906
Epoch: 0, Batch: 900, Loss: 0.4237833321094513
Epoch: 0, Batch: 1000, Loss: 0.34554892778396606
Epoch: 0, Batch: 1100, Loss: 0.5370961427688599
Epoch: 0, Batch: 1200, Loss: 0.49347299337387085
Epoch: 0, Batch: 1300, Loss: 0.5599779486656189
Epoch: 0, Batch: 1400, Loss: 0.347579687833786
Epoch: 0, Batch: 1500, Loss: 0.1854720562696457
Epoch: 0, Batch: 1600, Loss: 0.46506625413894653
Epoch: 0, Batch: 1700, Loss: 0.4102201759815216
Epoch: 0, Batch: 1800, Loss: 0.30582094192504883
Epoch: 0, Batch: 1900, Loss: 0.30795541405677795
Epoch: 0, Batch: 2000, Loss: 0.3772348165512085


In [8]:
@torch.no_grad()
def test():
    model = torch.load('model/7.model')
    model.eval()

    correct = 0
    total = 0
    for i in range(100):
        x, y = next(iter(loader))

        out = model(x).argmax(dim=1)

        correct += (out == y).sum().item()
        total += len(y)

    print(correct / total)


# Evaluate the saved checkpoint and print the resulting accuracy.
test()

0.89
