In [1]:
import torch

dataset = torch.rand(5000, 13, 768)
labels = torch.randint(0, 2, (5000, ))
embedding_size = 768
top_k = 10
dataset.shape

torch.Size([5000, 13, 768])

In [2]:
from sklearn.model_selection import train_test_split

train, test, train_labels, test_labels = train_test_split(dataset, labels, test_size=0.2, random_state=42)
train.shape, test.shape

(torch.Size([4000, 13, 768]), torch.Size([1000, 13, 768]))

In [3]:
from torch.utils.data import DataLoader, Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx][0].clone().detach().cuda(), self.data[idx][1].clone().detach().cuda(), self.data[idx][2].clone().detach().cuda(), self.data[idx][3:].clone().detach().cuda(), self.labels[idx].clone().detach().cuda()

train_dataset = CustomDataset(train, train_labels)
test_dataset = CustomDataset(test, test_labels)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

train_dataset.__getitem__(0)[0].shape, train_dataset.__getitem__(0)[1].shape, train_dataset.__getitem__(0)[2].shape, train_dataset.__getitem__(0)[3].shape, train_dataset.__getitem__(0)[4]

(torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([10, 768]),
 tensor(0, device='cuda:0'))

In [None]:
from torch import nn
from torch.functional import F

class Model(nn.Module):
    def __init__(self, embedding_size=768):
        super(Model, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(embedding_size*3, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

        
    def forward(self, text, candidate_title, candidate_abstract, text_citation):
        concat_text_citation = torch.cat([text.unsqueeze(1), text_citation], dim=1)

        att_title_citation = torch.bmm(F.normalize(concat_text_citation, p=2, dim=2), F.normalize(candidate_title.unsqueeze(1).permute(0, 2, 1), p=2, dim=2))
        att_title_citation_value = nn.functional.softmax(att_title_citation, dim=1)
        title_citation_after_attention = (concat_text_citation * att_title_citation_value).sum(dim=1)

        att_abstract_citation = torch.bmm(F.normalize(concat_text_citation, p=2, dim=2), F.normalize(candidate_abstract.unsqueeze(1).permute(0, 2, 1), p=2, dim=2))
        att_abstract_citation_value = nn.functional.softmax(att_abstract_citation, dim=1)
        abstract_citation_after_attention = (concat_text_citation * att_abstract_citation_value).sum(dim=1)


        x = torch.cat([text, title_citation_after_attention, abstract_citation_after_attention], dim=1)
        return self.encoder(x)
    
text = torch.randint(0, 5, (2, 3)).float().to('cuda')
candidate_title = torch.randint(0, 5, (2, 3)).float().to('cuda')
candidate_abstract = torch.randint(0, 5, (2, 3)).float().to('cuda')
text_citation = torch.randint(0, 5, (2, 3, 3)).float().to('cuda')
model = Model(embedding_size=text.shape[1]).to('cuda')
model(text, candidate_title, candidate_abstract, text_citation)

torch.Size([2, 4, 1])


tensor([[0.5149],
        [0.5149]], device='cuda:0', grad_fn=<SigmoidBackward0>)

In [23]:
def train_epoch(model, train_loader, criterion, optimizer):
    model.train()
    train_loss = 0
    correct = 0
    for text, candidate_title, candidate_abstract, text_citation, labels in train_loader:
        optimizer.zero_grad()
        output = model(text, candidate_title, candidate_abstract, text_citation)
        loss = criterion(output, labels.unsqueeze(1).to(torch.float32))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        correct += (output > 0.5).eq(labels.unsqueeze(1)).sum().item()
    return train_loss / len(train_loader), correct / len(train_loader.dataset)

def test_epoch(model, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for text, candidate_title, candidate_abstract, text_citation, labels in test_loader:
            output = model(text, candidate_title, candidate_abstract, text_citation)
            loss = criterion(output, labels.unsqueeze(1).to(torch.float32))
            test_loss += loss.item()
            correct += (output > 0.5).eq(labels.unsqueeze(1)).sum().item()
    return test_loss / len(test_loader), correct / len(test_loader.dataset)

criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model = Model(embedding_size=next(iter(train_loader))[0].shape[1]).to('cuda')

for epoch in range(10):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)
    test_loss, test_acc = test_epoch(model, test_loader, criterion)

    print(f'Epoch {epoch+1}')
    print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
    print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')

Epoch 1
Train Loss: 0.6937, Train Acc: 0.4915
Test Loss: 0.6932, Test Acc: 0.5050
Epoch 2
Train Loss: 0.6937, Train Acc: 0.4915
Test Loss: 0.6932, Test Acc: 0.5050
Epoch 3
Train Loss: 0.6937, Train Acc: 0.4915
Test Loss: 0.6932, Test Acc: 0.5050
Epoch 4
Train Loss: 0.6937, Train Acc: 0.4915
Test Loss: 0.6932, Test Acc: 0.5050
Epoch 5
Train Loss: 0.6937, Train Acc: 0.4915
Test Loss: 0.6932, Test Acc: 0.5050
Epoch 6
Train Loss: 0.6937, Train Acc: 0.4915
Test Loss: 0.6932, Test Acc: 0.5050
Epoch 7
Train Loss: 0.6937, Train Acc: 0.4915
Test Loss: 0.6932, Test Acc: 0.5050
Epoch 8
Train Loss: 0.6937, Train Acc: 0.4915
Test Loss: 0.6932, Test Acc: 0.5050
Epoch 9
Train Loss: 0.6937, Train Acc: 0.4915
Test Loss: 0.6932, Test Acc: 0.5050
Epoch 10
Train Loss: 0.6937, Train Acc: 0.4915
Test Loss: 0.6932, Test Acc: 0.5050
