In [10]:
#!pip install torch
#!pip install transformers

In [11]:
import torch
from transformers import AutoModel, AutoTokenizer
from load_data import load_abstracts, load_concepts
from torch.utils.data import Dataset, DataLoader
import numpy as np
import re

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

class AbstractDataset(Dataset):
    """Dataset pairing each abstract with its concepts.

    Only ids present in both ``abstracts`` and ``concepts`` are kept, in
    the iteration order of ``abstracts``.
    """

    def __init__(self, abstracts, concepts):
        """Keep the entries of both mappings whose id appears in each."""
        shared_ids = [key for key in abstracts if key in concepts]

        self.abstracts = {key: abstracts[key] for key in shared_ids}
        self.concepts = {key: concepts[key] for key in shared_ids}
        # Positional index -> id, so integer indexing works on the dicts.
        self.indices = list(self.abstracts)

    def __len__(self):
        """Return the number of retained abstracts."""
        return len(self.abstracts)

    def __getitem__(self, idx):
        """Return the abstract and its concepts at position ``idx``."""
        key = self.indices[idx]
        abstract = self.abstracts[key]
        concepts = self.concepts[key]
        return {'abstract': abstract, 'concepts': concepts}
    
def collate_fn(batch, tokenizer, max_length=512):
    """Collate dataset items into padded model inputs and per-token labels.

    Each abstract is tokenized without special tokens, and every
    case-insensitive occurrence of one of its concepts is marked at token
    level: ``1`` for the first token of an occurrence, ``2`` for its
    remaining tokens and ``0`` everywhere else.

    Args:
        batch: Items with an ``"abstract"`` string and a ``"concepts"``
            iterable of strings.
        tokenizer: Callable tokenizer whose single-text output provides
            ``tokens()`` and ``char_to_token()`` (presumably a Hugging Face
            fast tokenizer -- confirm against the caller).
        max_length: Truncation length passed to the tokenizer when
            building the model inputs.

    Returns:
        A ``(inputs, masks)`` tuple: the padded/truncated tokenizer output
        moved to ``device``, and a float32 tensor of labels with one row
        per item, zero-padded to the longest label row and capped at
        ``max_length - 2`` columns.
    """
    all_abstracts = []
    all_labels = []

    for item in batch:
        abstract = item["abstract"]
        all_abstracts.append(abstract)

        tokenizer_out = tokenizer(abstract, add_special_tokens=False)
        gt_array = np.zeros(len(tokenizer_out.tokens()))

        for concept in item["concepts"]:
            if not concept:
                # An empty pattern matches everywhere with zero width.
                continue

            # Escape the concept so regex metacharacters are matched
            # literally, and search the original text (not a lowered copy)
            # so match offsets are valid for char_to_token.
            pattern = re.escape(concept)
            for match in re.finditer(pattern, abstract, flags=re.IGNORECASE):
                # Several characters map to the same token, and characters
                # outside any token (e.g. whitespace) map to None, so
                # de-duplicate and drop None before ordering.
                token_indices = sorted({
                    token_index
                    for token_index in (
                        tokenizer_out.char_to_token(char_index)
                        for char_index in range(match.start(), match.end())
                    )
                    if token_index is not None
                })
                if not token_indices:
                    continue

                gt_array[token_indices[0]] = 1
                for other_index in token_indices[1:]:
                    gt_array[other_index] = 2

        all_labels.append(gt_array)

    inputs = tokenizer(
        all_abstracts,
        max_length=max_length,
        truncation=True,
        padding=True,
        return_tensors="pt",
    ).to(device)

    # Labels were built on the untruncated, special-token-free tokenization.
    # Cap them at max_length - 2 so they line up with the truncated inputs.
    # NOTE(review): the "- 2" assumes two special tokens per sequence, as in
    # the original code -- confirm for tokenizers other than DeBERTa.
    longest = max((len(labels) for labels in all_labels), default=0)
    width = min(longest, max(max_length - 2, 0))

    batch_masks = np.zeros((len(all_labels), width), dtype=np.float32)
    for row, labels in enumerate(all_labels):
        kept = labels[:width]
        batch_masks[row, :len(kept)] = kept

    masks = torch.tensor(batch_masks, dtype=torch.float32, device=device)

    return inputs, masks


def train():
    """Load the DeBERTa model and data, then print every batch for 30 epochs.

    The model and optimizer are created but not used inside the loop yet;
    the loop only prints the collated batches.
    """
    encoder = AutoModel.from_pretrained("microsoft/deberta-base")
    tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-base")
    optimizer = torch.optim.AdamW(
        encoder.parameters(),
        lr=2e-5,
        weight_decay=1e-4,
    )

    dataset = AbstractDataset(load_abstracts(), load_concepts())

    def collate(items):
        # Bind the tokenizer so the DataLoader can call this with the batch only.
        return collate_fn(items, tokenizer)

    train_loader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
        collate_fn=collate,
    )

    num_epochs = 30
    for _ in range(num_epochs):
        for batch in train_loader:
            print(batch)

# Start the training loop as soon as this cell/module is executed.
train()

100%|██████████| 37786/37786 [00:03<00:00, 11062.01it/s]
100%|██████████| 409/409 [00:00<00:00, 14736.45it/s]


KeyError: 'abstracts'