In [1]:
from utils.DevConf import DevConf
# Shared device configuration built for 'cuda'; it is passed to the model
# constructor and read through `devConf.device` later in the notebook.
devConf = DevConf('cuda')

# Load Data

In [2]:
from datasets import load_dataset

In [3]:
# Only the "train" split is loaded here.
# NOTE: trust_remote_code=True allows the dataset repository's own loading
# script to execute on this machine.
dataset = load_dataset(
    "carblacac/twitter-sentiment-analysis",
    split="train",
    trust_remote_code=True,
)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [4]:
from transformers import AutoTokenizer
from transformers import BatchEncoding

In [5]:
# Load the tokenizer published with the "distilbert-base-uncased" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

```python
dataset[0]
```
>{'text': '@fa6ami86 so happy that salman won.  btw the 14sec clip is truely a teaser', 'feeling': 0}

In [6]:
def tokenize_function(examples):
    """Tokenize the ``"text"`` field of ``examples`` with the notebook tokenizer.

    Args:
        examples: Mapping that holds the raw text under the ``"text"`` key
            (a whole batch of texts when used with ``map(..., batched=True)``).

    Returns:
        Whatever the tokenizer returns for those texts, with truncation
        enabled and the ``"max_length"`` padding strategy applied.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, truncation=True, padding="max_length")
    return encoded

In [7]:
# Apply tokenize_function over the dataset in batches; the resulting rows gain
# the 'input_ids' and 'attention_mask' fields (see the sample row below).
dataset = dataset.map(tokenize_function, batched=True)

dataset[0]
> {'text': '@fa6ami86 so happy that salman won.  btw the 14sec clip is truely a teaser',\
> 'feeling': 0,\
> 'input_ids': [...],\
> 'attention_mask': [...]}

In [8]:
# Switch the output format to "torch" and expose only the three listed columns;
# the raw 'text' column no longer appears in rows (see the sample row below).
dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "feeling"])

dataset[0]
> {'feeling': 0,\
> 'input_ids': [...],\
> 'attention_mask': [...]}

In [9]:
from torch.utils.data import DataLoader

In [10]:
# Batches of 16 examples. DataLoader does not shuffle by default, so shuffling
# is requested explicitly: training should not see the examples in their
# stored order on every pass.
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# Define Model

In [11]:
from model import SentiDistilBert

In [12]:
# Instantiate SentiDistilBert, handing it the shared device configuration.
myModel = SentiDistilBert(devConf=devConf)

## Test Forward

In [13]:
# Encode one sample sentence as PyTorch tensors and place it on the configured
# device (rather than a hard-coded "cuda") so it follows the same device
# setting as the training batches.
inputs: BatchEncoding = tokenizer("Hello, my dog is cute", return_tensors="pt").to(devConf.device)

In [14]:
# Smoke-test a forward pass on the single encoded sentence; the recorded output
# below is a 1x2 tensor produced by a sigmoid.
myModel(**inputs)

tensor([[0.5058, 0.4351]], device='cuda:0', grad_fn=<SigmoidBackward0>)

# Train

In [15]:
def train(model: SentiDistilBert, dataloader):
    """Run one pass over ``dataloader``, updating ``model`` after every batch.

    Args:
        model: Network to optimise. It must expose an ``optimizer`` attribute,
            which is stepped and then cleared once per batch.
        dataloader: Iterable of batches, each a mapping from field name to
            tensor. Every value is moved to ``devConf.device`` and the mapping
            is unpacked as keyword arguments into ``model``.
    """
    model.train()
    # ``Module.to()`` without arguments moves nothing; name the target device
    # explicitly so the parameters live where the batches are sent.
    model.to(devConf.device)
    for batch in dataloader:
        # NOTE(review): every key of the batch (including the label column)
        # is forwarded to the model — confirm its forward() accepts them all.
        inputs = {k: v.to(devConf.device) for k, v in batch.items()}
        outputs = model(**inputs)
        # NOTE(review): the forward smoke test in this notebook printed a plain
        # tensor, which has no ``.loss`` attribute — confirm the model returns
        # an object carrying ``.loss`` when it is given labels.
        loss = outputs.loss
        loss.backward()
        model.optimizer.step()
        # Clear gradients so they do not accumulate into the next batch.
        model.optimizer.zero_grad()