In [1]:
import torch
import transformers
from tqdm import tqdm
import torch.nn as nn
import numpy as np
from sklearn.metrics import f1_score, accuracy_score
import gc
import json
from time import time
import os
import albumentations as A
import cv2
from dataclasses import dataclass
import torch.nn.functional as F

# Put the project root on the import path so the `src.*` imports below resolve.
# NOTE(review): this is an absolute, machine-specific path; the notebook will not
# import `src` on another machine unless the path is adjusted.
import sys
sys.path.insert(0, "/home/dzigen/Desktop/ITMO/ВКР/КМУ2024/")

from src.readers.fid import FiDReader
from src.readers.archs.fid_model import FiDT5

In [4]:
# Load the pretrained Flan-T5 (base) checkpoint and build the project's FiDT5 from
# its configuration.
# NOTE(review): only `t5.config` is passed to FiDT5 in this cell; no pretrained
# weights are visibly transferred here — confirm whether FiDT5 loads them itself.
t5 = transformers.T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
model = FiDT5(t5.config)

In [None]:
# Instantiate the project's FiD reader with its default arguments; later cells use
# its `.model` and `.tokenizer` attributes.
reader = FiDReader()

In [13]:
# Random dummy batch for a smoke test of the FiDT5 forward pass.
# `ids` / `mask` are 3-D (4, 1, 512); `labels` is 2-D (4, 512).
ids = torch.randint(0, 255, size=(4,1,512))
# An attention mask is binary (1 = attend, 0 = ignore), so sample from {0, 1}.
# The upper bound of torch.randint is exclusive, hence `2`.
mask = torch.randint(0, 2, size=(4,1,512))
labels = torch.randint(0, 255, size=(4,512))

In [14]:
# Forward pass of the freshly built FiDT5 on the dummy batch, with labels supplied.
# The debug lines printed below trace the tensor shapes as the inputs are reshaped.
model(input_ids=ids, attention_mask=mask, labels=labels)

forward 2
init dimension:  torch.Size([4, 1, 512])
2d resize:  torch.Size([4, 512])
forward 3
2d resize:  torch.Size([4, 512])
---
candidates flat:  torch.Size([4, 512])
encoder output  torch.Size([4, 512, 768])
candidates concatenation:
last_hidden_state torch.Size([4, 512, 768])


In [3]:
# Random dummy batch with two entries along the middle axis: inputs are (2, 2, 512),
# labels are (2, 512).
input_ids = torch.randint(0, 255, size=(2,2,512))
# An attention mask is binary (1 = attend, 0 = ignore), so sample from {0, 1}.
# The upper bound of torch.randint is exclusive, hence `2`.
attention_mask = torch.randint(0, 2, size=(2,2,512))
labels = torch.randint(0, 255, size=(2,512))

In [None]:
# Run generation on the dummy batch with max_length=10 and keep the generated
# token ids in `out`.
out = reader.model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=10)

In [11]:
# Decode the generated ids back to strings, dropping special tokens.
reader.tokenizer.batch_decode(out, skip_special_tokens=True)

['', '']

In [5]:
# Inspect the shape of the generated id tensor.
out.shape

torch.Size([2, 10])

In [4]:
# Forward pass of the reader's model on the dummy batch, with labels supplied.
# NOTE(review): this rebinds `out`, which other cells use for the generated token ids.
out = reader.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

forward 2
init dimension:  torch.Size([2, 2, 512])
2d resize:  torch.Size([2, 1024])
forward 3
2d resize:  torch.Size([2, 1024])
---
candidates flat:  torch.Size([4, 512])
encoder output  torch.Size([4, 512, 768])
candidates concatenation:
last_hidden_state torch.Size([2, 1024, 768])


### Only Reader Train

### Only Retriever Train

In [86]:
# Standard cross-entropy criterion.
# NOTE(review): the following cell creates the same object again, so this cell is redundant.
criterion = nn.CrossEntropyLoss()

In [92]:
# Sanity check: nn.CrossEntropyLoss should equal the mean negative log of the
# softmax probability assigned to each sample's target class.
criterion = nn.CrossEntropyLoss()

# Identity-matrix logits (a random matrix such as torch.randn(6,6) works as well);
# sample i has target class i.
targets = torch.arange(0, 6)
output_score = torch.eye(6)

auto_loss = criterion(output_score, targets)
print(auto_loss)

# Hand-rolled version: softmax over classes, pick each target's probability,
# then average the negative logs.
target_probs = F.softmax(output_score, dim=1).gather(1, targets.view(-1, 1))
manual_loss = (-torch.log(target_probs)).mean()
print(manual_loss)

tensor(1.0436)
tensor(1.0436)


In [84]:
# Toy inputs for a three-argument loss: a scalar top-k reader loss, one row of
# per-candidate reader losses and one row of retriever scores.
# NOTE(review): the `criterion` defined in the visible cells above is
# nn.CrossEntropyLoss, which does not take three inputs; presumably `criterion`
# was a JoinLoss instance (class defined at the end of the notebook) when this
# cell was run — confirm and make the dependency explicit.
reader_topk_loss = 0.2
reader_k_loss = torch.tensor([[0.9,0.1,0.5]])
retriever_k_scores = torch.tensor([[5,1,3]])
criterion(reader_topk_loss, reader_k_loss, retriever_k_scores)

tensor(0.0261)

In [83]:
# Second toy example for the three-argument loss, with a different scalar reader
# loss, per-candidate losses and retriever scores.
# NOTE(review): as with the previous example, `criterion` must be something other
# than the nn.CrossEntropyLoss defined above for this three-argument call to work;
# presumably a JoinLoss instance — confirm.
reader_topk_loss = 0.8
reader_k_loss = torch.tensor([[0.9,0.8,1]])
retriever_k_scores = torch.tensor([[5,3,3]])
criterion(reader_topk_loss, reader_k_loss, retriever_k_scores)

tensor(0.6946)

In [95]:
# Row-wise argmax: index of the largest entry in each row of a 2x3 tensor.
torch.tensor([[1,2,3],[4,5,6]]).argmax(dim=1)

tensor([2, 2])

### Reader + Frozen Retriever Train

In [7]:
import torch

In [12]:
# Random dummy batch of shape (1, 2, 512) for the inputs and (1, 512) for the labels.
batch_shape = (1, 2, 512)

input_ids = torch.randint(0, 255, size=batch_shape)
# Binary attention mask: the exclusive upper bound 2 yields values in {0, 1}.
attention_mask = torch.randint(0, 2, size=batch_shape)

labels = torch.randint(0, 255, size=(batch_shape[0], batch_shape[-1]))

In [25]:
import torch.nn as nn

In [21]:
# Tokenize a sample string (with special tokens enabled) to inspect the resulting
# ids and attention mask.
reader.tokenizer(" 2 hello world", add_special_tokens=True)

{'input_ids': [204, 21820, 296, 1], 'attention_mask': [1, 1, 1, 1]}

In [24]:
# Forward pass of the reader's model with labels; later cells read `.logits` and
# `.loss` from the returned object.
output = reader.model(input_ids=input_ids,attention_mask=attention_mask, labels=labels)

forward 2
init dimension:  torch.Size([1, 2, 512])
2d resize:  torch.Size([1, 1024])
forward 3
2d resize:  torch.Size([1, 1024])
---
candidates flat:  torch.Size([2, 512])
encoder output  torch.Size([2, 512, 768])
candidates concatenation:
last_hidden_state torch.Size([1, 1024, 768])


In [59]:
# Inspect the shape of the returned logits.
output.logits.shape

torch.Size([1, 512, 32128])

In [60]:
# Inspect the shape of the labels for comparison with the logits.
labels.shape

torch.Size([1, 512])

In [111]:
# Softmax over the last axis of the logits, then read the probability at
# index [0][0][111] (first sequence, first position, id 111).
F.softmax(output.logits, dim=-1)[0][0][111]

tensor(1.2608e-06, grad_fn=<SelectBackward0>)

In [106]:
# Same probe via gather: flatten the logits to (512, 32128), softmax over the last
# axis, then select entries by label id.
# NOTE(review): `labels` is printed above as shape (1, 512); with gather on dim 1
# such an index only reads row 0 of the flattened probabilities, so only element
# [0][0] pairs a position with its own label — confirm this is intended.
F.softmax(output.logits.view(512, 32128), dim=-1).gather(1,labels)[0][0]

tensor(1.2608e-06, grad_fn=<SelectBackward0>)

In [137]:
# Loss reported by the model itself, used as the reference value for the manual
# computation in the next cell.
output.loss

tensor(15.1423, grad_fn=<NllLossBackward0>)

In [135]:
# Recompute the loss by hand: log-probability of each label token, negated and
# averaged over the sequence axis, giving one value per batch row.
# log_softmax is used instead of log(softmax(...)) for numerical stability, and
# unsqueeze/squeeze replace the hard-coded (1, 512) shapes.
token_log_probs = F.log_softmax(output.logits, dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1)
outs = token_log_probs.neg().mean(dim=1)

In [136]:
# Display the manually computed loss for comparison with `output.loss`.
outs

tensor([15.1423], grad_fn=<MeanBackward1>)

In [124]:
# Repeat of the earlier probe: softmax over the last axis of the logits, read at
# index [0][0][111].
F.softmax(output.logits, dim=-1)[0][0][111]

tensor(1.2608e-06, grad_fn=<SelectBackward0>)

In [3]:
# Instantiate the project's FiD reader with its default arguments.
# NOTE(review): `reader` is also created near the top of the notebook; this cell
# repeats that step.
reader = FiDReader()

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on google-t5/t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
# Tokenize a batch of two strings, padding and truncating each to 512 tokens and
# returning PyTorch tensors.
reader.tokenizer(["hello", "world"], max_length=512, padding='max_length', 
            return_tensors='pt', truncation=True)

{'input_ids': tensor([[21820,     1,     0,  ...,     0,     0,     0],
        [  296,     1,     0,  ...,     0,     0,     0]]), 'attention_mask': tensor([[1, 1, 0,  ..., 0, 0, 0],
        [1, 1, 0,  ..., 0, 0, 0]])}

In [7]:
# Tokenize "hello" and "world" passed as two separate positional arguments,
# returning PyTorch tensors.
# NOTE(review): presumably the second argument is treated as a paired text —
# confirm against the tokenizer's call signature.
encoding = reader.tokenizer("hello", "world", return_tensors="pt")

In [10]:
# Show the tokenizer's end-of-sequence token string.
reader.tokenizer.eos_token

'</s>'

In [None]:
# Shift-by-one alignment: drop the last step of `output` and the first token of
# `trg_batch`, then flatten (`output` to 2-D over its last axis, targets to 1-D).
# The second statement was indented by mistake, which made the cell a syntax error.
# NOTE(review): `trg_batch` is not defined in the visible notebook, and `output`
# must be a plain tensor for this slicing to work — presumably pasted from a
# training loop; confirm before running.
output_reshape = output[:,:-1].contiguous().view(-1, output.shape[-1])
trg_batch = trg_batch[:, 1:].contiguous().view(-1)

### Joint Reader + Retriever Train

In [None]:
class JoinLoss:
    '''
    Joint reader + retriever loss.

    Adds, to the reader's loss, the batch mean of the log of the per-candidate
    reader losses weighted by the softmax of the retriever scores.
    '''

    def __init__(self, r=1) -> None:
        '''
        params:
            r: temperature that divides the retriever scores before the softmax
        '''
        self.temp = r

    def __call__(self, reader_topk_loss, reader_k_loss, retriever_k_scores):
        '''
        params:
            reader_topk_loss: 1
            reader_k_loss: BxN
            retriever_k_scores: BxN

        output:
            scores: 1
        '''
        # Distribution over the N candidates induced by the retriever scores.
        retriever_probs = F.softmax(retriever_k_scores / self.temp, dim=1)
        # Per-sample sum of the candidate reader losses weighted by that distribution.
        weighted_loss = torch.sum(retriever_probs * reader_k_loss, dim=1)
        retriever_part = torch.mean(torch.log(weighted_loss))

        return reader_topk_loss + retriever_part

    def k_loss(self, reader_logits, labels):
        '''
        Mean token negative log-likelihood of the labels for every candidate.

        params:
            reader_logits: BxNxLxVOCAB_SIZE tensor, or an object exposing such a
                tensor as `.logits`
            labels: BxL

        output:
            scores: BxN

        Every one of the L positions contributes to the mean; no label value is
        treated as padding.
        '''
        # Accept either a raw tensor or a model-output-like object with `.logits`.
        logits = getattr(reader_logits, "logits", reader_logits)
        bsz, k, seq_len = logits.shape[0], logits.shape[1], logits.shape[2]

        # gather needs the index to span the candidate axis as well, so the same
        # labels are broadcast to each of the k candidates.
        index = labels[:, None, :, None].expand(bsz, k, seq_len, 1)

        # log_softmax avoids the underflow of log(softmax(...)).
        token_log_probs = F.log_softmax(logits, dim=-1).gather(3, index).squeeze(-1)

        return -token_log_probs.mean(dim=2)