In [1]:
# !pip install sentencepiece
# !pip install transformers
# !pip install rich[jupyter]

In [2]:
# translate_template_de_en = lambda x:  "Translate '%s' from german to english"%x

In [3]:
import pandas as pd
from tqdm import tqdm
# df = pd.read_csv("https://raw.githubusercontent.com/Shivanandroy/T5-Finetuning-PyTorch/main/data/news_summary.csv")
# WMT16 dataset
from datasets import load_dataset, DownloadMode, load_metric

train_df = load_dataset("wmt16", "de-en", split="train[:7000]", num_proc=8)
# eval_df = load_dataset("wmt16", "de-en", split="validation")

# Create data frame
df = pd.DataFrame({
    "source_text":[i['translation']['en'] for i in tqdm(train_df)],
    "target_text":[i['translation']['de'] for i in tqdm(train_df)]
}).applymap(str)

# eval_df_ds = pd.DataFrame({
#     "source_text":[i['translation']['en'] for i in eval_df],
#     "target_text":[i['translation']['de'] for i in eval_df]
# })
# eval_df_ds = eval_df_ds.applymap(str)

Found cached dataset wmt16 (/home/agupt135/.cache/huggingface/datasets/wmt16/de-en/1.0.0/746749a11d25c02058042da7502d973ff410e73457f3d305fc1177dc0e8c4227)
100%|██████████████████| 7000/7000 [00:00<00:00, 11489.84it/s]
100%|██████████████████| 7000/7000 [00:00<00:00, 11609.14it/s]


In [4]:
# df.to_csv("de_en_translation.csv")

In [5]:
# df = train_df_ds

In [6]:
# type(df)

In [7]:
# import pandas as pd
# df = pd.read_csv("de_en_translation.csv")
# df = df.iloc[:,1:]

In [8]:
df.head()

Unnamed: 0,source_text,target_text
0,Resumption of the session,Wiederaufnahme der Sitzungsperiode
1,I declare resumed the session of the European ...,"Ich erkläre die am Freitag, dem 17. Dezember u..."
2,"Although, as you will have seen, the dreaded '...","Wie Sie feststellen konnten, ist der gefürchte..."
3,You have requested a debate on this subject in...,Im Parlament besteht der Wunsch nach einer Aus...
4,"In the meantime, I should like to observe a mi...",Heute möchte ich Sie bitten - das ist auch der...


In [9]:
print(df.get("source_text")[0],"\n",df.get("target_text")[0],"\n",df.get("source_text")[1],"\n",df.get("target_text")[1])

Resumption of the session 
 Wiederaufnahme der Sitzungsperiode 
 I declare resumed the session of the European Parliament adjourned on Friday 17 December 1999, and I would like once again to wish you a happy new year in the hope that you enjoyed a pleasant festive period. 
 Ich erkläre die am Freitag, dem 17. Dezember unterbrochene Sitzungsperiode des Europäischen Parlaments für wiederaufgenommen, wünsche Ihnen nochmals alles Gute zum Jahreswechsel und hoffe, daß Sie schöne Ferien hatten.


In [10]:
# Importing libraries
import os
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import os
import gc
from GPUtil import showUtilization as gpu_usage
from numba import cuda

# Importing the T5 modules from huggingface/transformers
from transformers import T5Tokenizer, T5ForConditionalGeneration, pipeline

from rich.table import Column, Table
from rich import box
from rich.console import Console

# define a rich console logger
console=Console(record=True)

def display_df(df):
  """display dataframe in ASCII format"""

  console=Console()
  table = Table(Column("input_text", justify="center" ), Column("source_text", justify="center"), title="Translation Data",pad_edge=False, box=box.ASCII)

  for i, row in enumerate(df.values.tolist()):
    table.add_row(row[0], row[1])

  console.print(table)

training_logger = Table(Column("Epoch", justify="center" ), 
                        Column("Steps", justify="center"),
                        Column("Loss", justify="center"), 
                        title="Training Status",pad_edge=False, box=box.ASCII)


2023-03-25 07:31:09.436413: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-03-25 07:31:09.436635: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


In [11]:
# Setting up the device for GPU usage
from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'
print("device : ",device)

device :  cuda


In [12]:
class YourDataSetClass(Dataset):
  """
  Creating a custom dataset for reading the dataset and 
  loading it into the dataloader to pass it to the neural network for finetuning the model

  """

  def __init__(self, dataframe, tokenizer, source_len, target_len, source_text, target_text):
    self.tokenizer = tokenizer
    self.data = dataframe
    self.source_len = source_len
    self.summ_len = target_len
    self.target_text = self.data[target_text]
    self.source_text = self.data[source_text]

  def __len__(self):
    return len(self.target_text)

  def __getitem__(self, index):
    source_text = str(self.source_text[index])
    target_text = str(self.target_text[index])

    #cleaning data so as to ensure data is in string type
    source_text = ' '.join(source_text.split())
    target_text = ' '.join(target_text.split())

    source = self.tokenizer.batch_encode_plus([source_text], max_length= self.source_len, pad_to_max_length=True, truncation=True, padding="max_length", return_tensors='pt')
    target = self.tokenizer.batch_encode_plus([target_text], max_length= self.summ_len, pad_to_max_length=True, truncation=True, padding="max_length", return_tensors='pt')

    source_ids = source['input_ids'].squeeze()
    source_mask = source['attention_mask'].squeeze()
    target_ids = target['input_ids'].squeeze()
    target_mask = target['attention_mask'].squeeze()

    return {
        'source_ids': source_ids.to(dtype=torch.long), 
        'source_mask': source_mask.to(dtype=torch.long), 
        'target_ids': target_ids.to(dtype=torch.long),
        'target_ids_y': target_ids.to(dtype=torch.long)
    }

In [13]:
def train(epoch, tokenizer, model, device, loader, optimizer):

  """
  Function to be called for training with the parameters passed from main function

  """

  model.train()
  for _,data in tqdm(enumerate(loader, 0)):
    y = data['target_ids'].to(device, dtype = torch.long)
    y_ids = y[:, :-1].contiguous()
    lm_labels = y[:, 1:].clone().detach()
    lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100
    ids = data['source_ids'].to(device, dtype = torch.long)
    mask = data['source_mask'].to(device, dtype = torch.long)
    
    torch.cuda.empty_cache()
    outputs = model(input_ids = ids, attention_mask = mask, decoder_input_ids=y_ids, labels=lm_labels)
    loss = outputs[0]

    if _%10==0:
      training_logger.add_row(str(epoch), str(_), str(loss))
      console.print(training_logger)

    optimizer.zero_grad()
    torch.cuda.empty_cache()
    loss.backward()
    optimizer.step()

In [14]:
def validate(epoch, tokenizer, model, device, loader):

  """
  Function to evaluate model for predictions

  """
  model.eval()
  predictions = []
  actuals = []
  with torch.no_grad():
      for _, data in tqdm(enumerate(loader, 0)):
          y = data['target_ids'].to(device, dtype = torch.long)
          ids = data['source_ids'].to(device, dtype = torch.long)
          mask = data['source_mask'].to(device, dtype = torch.long)

          generated_ids = model.generate(
              input_ids = ids,
              attention_mask = mask, 
              max_length=150, 
              num_beams=2,
              repetition_penalty=2.5, 
              length_penalty=1.0, 
              early_stopping=True
              )
          preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
          target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True)for t in y]
          if _%10==0:
              console.print(f'Completed {_}')

          predictions.extend(preds)
          actuals.extend(target)
  return predictions, actuals

In [15]:
def T5Trainer(dataframe, source_text, target_text, model_params, output_dir="./outputs/en-de-translation/" ):
  
  """
  T5 trainer

  """

  # Set random seeds and deterministic pytorch for reproducibility
  torch.manual_seed(model_params["SEED"]) # pytorch random seed
  np.random.seed(model_params["SEED"]) # numpy random seed
  torch.backends.cudnn.deterministic = True

  # logging
  console.log(f"""[Model]: Loading {model_params["MODEL"]}...\n""")

  # tokenzier for encoding the text
  tokenizer = T5Tokenizer.from_pretrained(model_params["MODEL"])

  # Defining the model. We are using t5-base model and added a Language model layer on top for generation of Summary. 
  # Further this model is sent to device (GPU/TPU) for using the hardware.
  gc.collect()
  model = T5ForConditionalGeneration.from_pretrained(model_params["MODEL"])
  model = model.to(device)
  
  # logging
  console.log(f"[Data]: Reading data...\n")

  # Importing the raw dataset
  dataframe = dataframe[[source_text,target_text]]
  display_df(dataframe.head(2))

  
  # Creation of Dataset and Dataloader
  # Defining the train size. So 80% of the data will be used for training and the rest for validation. 
  train_size = 0.8
  train_dataset=dataframe.sample(frac=train_size,random_state = model_params["SEED"])
  val_dataset=dataframe.drop(train_dataset.index).reset_index(drop=True)
  train_dataset = train_dataset.reset_index(drop=True)

  console.print(f"FULL Dataset: {dataframe.shape}")
  console.print(f"TRAIN Dataset: {train_dataset.shape}")
  console.print(f"TEST Dataset: {val_dataset.shape}\n")


  # Creating the Training and Validation dataset for further creation of Dataloader
  training_set = YourDataSetClass(train_dataset, tokenizer, model_params["MAX_SOURCE_TEXT_LENGTH"], model_params["MAX_TARGET_TEXT_LENGTH"], source_text, target_text)
  val_set = YourDataSetClass(val_dataset, tokenizer, model_params["MAX_SOURCE_TEXT_LENGTH"], model_params["MAX_TARGET_TEXT_LENGTH"], source_text, target_text)


  # Defining the parameters for creation of dataloaders
  train_params = {
      'batch_size': model_params["TRAIN_BATCH_SIZE"],
      'shuffle': True,
      'num_workers': 0
      }


  val_params = {
      'batch_size': model_params["VALID_BATCH_SIZE"],
      'shuffle': False,
      'num_workers': 0
      }


  # Creation of Dataloaders for testing and validation. This will be used down for training and validation stage for the model.
  training_loader = DataLoader(training_set, **train_params)
  val_loader = DataLoader(val_set, **val_params)


  # Defining the optimizer that will be used to tune the weights of the network in the training session. 
  optimizer = torch.optim.Adam(params =  model.parameters(), lr=model_params["LEARNING_RATE"])


  # Training loop
  console.log(f'[Initiating Fine Tuning]...\n')

  for epoch in tqdm(range(model_params["TRAIN_EPOCHS"])):
    train(epoch, tokenizer, model, device, training_loader, optimizer)
      
  console.log(f"[Saving Model]...\n")
  #Saving the model after training
  path = os.path.join(output_dir, "model_files")
  model.save_pretrained(path)
  tokenizer.save_pretrained(path)


  # evaluating test dataset
  console.log(f"[Initiating Validation]...\n")
  for epoch in tqdm(range(model_params["VAL_EPOCHS"])):
    predictions, actuals = validate(epoch, tokenizer, model, device, val_loader)
    final_df = pd.DataFrame({'Generated Text':predictions,'Actual Text':actuals})
    final_df.to_csv(os.path.join(output_dir,'predictions.csv'))
  
  console.save_text(os.path.join(output_dir,'logs.txt'))
  
  console.log(f"[Validation Completed.]\n")
  console.print(f"""[Model] Model saved @ {os.path.join(output_dir, "model_files")}\n""")
  console.print(f"""[Validation] Generation on Validation data saved @ {os.path.join(output_dir,'predictions.csv')}\n""")
  console.print(f"""[Logs] Logs saved @ {os.path.join(output_dir,'logs.txt')}\n""")
  text = 'Resumption of the session'
  inference(trained_model=model, trained_tokenizer=tokenizer, text=text)

In [16]:
model_params={
    "MODEL":"t5-small",             # model_type: t5-base/t5-large
    "TRAIN_BATCH_SIZE":64,          # training batch size
    "VALID_BATCH_SIZE":64,          # validation batch size
    "TRAIN_EPOCHS":4,              # number of training epochs
    "VAL_EPOCHS":5,                # number of validation epochs
    "LEARNING_RATE":1e-4,          # learning rate
    "MAX_SOURCE_TEXT_LENGTH":120,  # max length of source text
    "MAX_TARGET_TEXT_LENGTH":120,   # max length of target text
    "SEED": 42,                     # set seed for reproducibility 
    "requires_grad":True
}

In [17]:
def inference(trained_model, trained_tokenizer, text):
    generator = pipeline(
        "text2text-generation",
        model=trained_model,
        tokenizer=trained_tokenizer,
        device=0 if torch.cuda.is_available() else -1, # use GPU if available
    )

    input_text = text # input text for generation
    generated_text = generator(input_text, max_length=50)
    print(generated_text)

In [18]:
T5Trainer(dataframe=df, 
          source_text="source_text", 
          target_text="target_text", 
          model_params=model_params, 
          output_dir="./FAMNet/HoSoTAs/en-de-new")

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-small automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


  0%|                                   | 0/4 [00:00<?, ?it/s]
0it [00:00, ?it/s][A


1it [00:01,  1.36s/it][A
2it [00:02,  1.17s/it][A
3it [00:03,  1.11s/it][A
4it [00:04,  1.08s/it][A
5it [00:05,  1.07s/it][A
6it [00:06,  1.06s/it][A
7it [00:07,  1.05s/it][A
8it [00:08,  1.05s/it][A
9it [00:09,  1.05s/it][A
10it [00:10,  1.05s/it][A


11it [00:11,  1.05s/it][A
12it [00:12,  1.05s/it][A
13it [00:13,  1.05s/it][A
14it [00:14,  1.04s/it][A
15it [00:15,  1.04s/it][A
16it [00:17,  1.05s/it][A
17it [00:18,  1.04s/it][A
18it [00:19,  1.04s/it][A
19it [00:20,  1.04s/it][A
20it [00:21,  1.03s/it][A


21it [00:22,  1.04s/it][A
22it [00:23,  1.03s/it][A
23it [00:24,  1.03s/it][A
24it [00:25,  1.04s/it][A
25it [00:26,  1.03s/it][A
26it [00:27,  1.04s/it][A
27it [00:28,  1.03s/it][A
28it [00:29,  1.03s/it][A
29it [00:30,  1.03s/it][A
30it [00:31,  1.03s/it][A


31it [00:32,  1.04s/it][A
32it [00:33,  1.04s/it][A
33it [00:34,  1.04s/it][A
34it [00:35,  1.04s/it][A
35it [00:36,  1.04s/it][A
36it [00:37,  1.03s/it][A
37it [00:38,  1.03s/it][A
38it [00:39,  1.03s/it][A
39it [00:40,  1.03s/it][A
40it [00:41,  1.03s/it][A


41it [00:42,  1.03s/it][A
42it [00:43,  1.03s/it][A
43it [00:44,  1.03s/it][A
44it [00:45,  1.04s/it][A
45it [00:47,  1.04s/it][A
46it [00:48,  1.04s/it][A
47it [00:49,  1.04s/it][A
48it [00:50,  1.03s/it][A
49it [00:51,  1.03s/it][A
50it [00:52,  1.03s/it][A


51it [00:53,  1.04s/it][A
52it [00:54,  1.03s/it][A
53it [00:55,  1.03s/it][A
54it [00:56,  1.03s/it][A
55it [00:57,  1.03s/it][A
56it [00:58,  1.03s/it][A
57it [00:59,  1.03s/it][A
58it [01:00,  1.04s/it][A
59it [01:01,  1.04s/it][A
60it [01:02,  1.04s/it][A


61it [01:03,  1.04s/it][A
62it [01:04,  1.04s/it][A
63it [01:05,  1.04s/it][A
64it [01:06,  1.05s/it][A
65it [01:07,  1.04s/it][A
66it [01:08,  1.04s/it][A
67it [01:09,  1.05s/it][A
68it [01:10,  1.05s/it][A
69it [01:11,  1.04s/it][A
70it [01:12,  1.04s/it][A


71it [01:14,  1.05s/it][A
72it [01:15,  1.04s/it][A
73it [01:16,  1.05s/it][A
74it [01:17,  1.05s/it][A
75it [01:18,  1.05s/it][A
76it [01:19,  1.05s/it][A
77it [01:20,  1.05s/it][A
78it [01:21,  1.04s/it][A
79it [01:22,  1.04s/it][A
80it [01:23,  1.04s/it][A


81it [01:24,  1.04s/it][A
82it [01:25,  1.04s/it][A
83it [01:26,  1.04s/it][A
84it [01:27,  1.04s/it][A
85it [01:28,  1.04s/it][A
86it [01:29,  1.04s/it][A
87it [01:30,  1.04s/it][A
88it [01:31,  1.04s/it][A
 25%|██████▊                    | 1/4 [01:31<04:34, 91.48s/it]
0it [00:00, ?it/s][A


1it [00:00,  1.14it/s][A
2it [00:01,  1.02it/s][A
3it [00:02,  1.01s/it][A
4it [00:03,  1.02s/it][A
5it [00:05,  1.02s/it][A
6it [00:06,  1.03s/it][A
7it [00:07,  1.03s/it][A
8it [00:08,  1.03s/it][A
9it [00:09,  1.03s/it][A
10it [00:10,  1.03s/it][A


11it [00:11,  1.04s/it][A
12it [00:12,  1.04s/it][A
13it [00:13,  1.03s/it][A
14it [00:14,  1.03s/it][A
15it [00:15,  1.03s/it][A
16it [00:16,  1.02s/it][A
17it [00:17,  1.03s/it][A
18it [00:18,  1.03s/it][A
19it [00:19,  1.03s/it][A
20it [00:20,  1.03s/it][A


21it [00:21,  1.05s/it][A
22it [00:22,  1.04s/it][A
23it [00:23,  1.04s/it][A
24it [00:24,  1.04s/it][A
25it [00:25,  1.04s/it][A
26it [00:26,  1.04s/it][A
27it [00:27,  1.04s/it][A
28it [00:28,  1.04s/it][A
29it [00:29,  1.05s/it][A
30it [00:30,  1.05s/it][A


31it [00:32,  1.06s/it][A
32it [00:33,  1.05s/it][A
33it [00:34,  1.05s/it][A
34it [00:35,  1.05s/it][A
35it [00:36,  1.05s/it][A
36it [00:37,  1.05s/it][A
37it [00:38,  1.04s/it][A
38it [00:39,  1.04s/it][A
39it [00:40,  1.04s/it][A
40it [00:41,  1.04s/it][A


41it [00:42,  1.04s/it][A
42it [00:43,  1.04s/it][A
43it [00:44,  1.03s/it][A
44it [00:45,  1.04s/it][A
45it [00:46,  1.04s/it][A
46it [00:47,  1.04s/it][A
47it [00:48,  1.04s/it][A
48it [00:49,  1.04s/it][A
49it [00:50,  1.04s/it][A
50it [00:51,  1.04s/it][A


51it [00:52,  1.04s/it][A
52it [00:53,  1.04s/it][A
53it [00:54,  1.04s/it][A
54it [00:55,  1.04s/it][A
55it [00:57,  1.04s/it][A
56it [00:58,  1.04s/it][A
57it [00:59,  1.04s/it][A
58it [01:00,  1.03s/it][A
59it [01:01,  1.03s/it][A
60it [01:02,  1.04s/it][A


61it [01:03,  1.05s/it][A
62it [01:04,  1.04s/it][A
63it [01:05,  1.03s/it][A
64it [01:06,  1.04s/it][A
65it [01:07,  1.04s/it][A
66it [01:08,  1.04s/it][A
67it [01:09,  1.04s/it][A
68it [01:10,  1.04s/it][A
69it [01:11,  1.03s/it][A
70it [01:12,  1.03s/it][A


71it [01:13,  1.04s/it][A
72it [01:14,  1.04s/it][A
73it [01:15,  1.04s/it][A
74it [01:16,  1.03s/it][A
75it [01:17,  1.03s/it][A
76it [01:18,  1.04s/it][A
77it [01:19,  1.04s/it][A
78it [01:20,  1.04s/it][A
79it [01:21,  1.04s/it][A
80it [01:22,  1.03s/it][A


81it [01:23,  1.04s/it][A
82it [01:24,  1.04s/it][A
83it [01:26,  1.04s/it][A
84it [01:27,  1.03s/it][A
85it [01:28,  1.04s/it][A
86it [01:29,  1.04s/it][A
87it [01:30,  1.04s/it][A
88it [01:30,  1.03s/it][A
 50%|█████████████▌             | 2/4 [03:02<03:02, 91.16s/it]
0it [00:00, ?it/s][A


1it [00:00,  1.17it/s][A
2it [00:01,  1.04it/s][A
3it [00:02,  1.00s/it][A
4it [00:03,  1.01s/it][A
5it [00:04,  1.01s/it][A
6it [00:06,  1.02s/it][A
7it [00:07,  1.03s/it][A
8it [00:08,  1.03s/it][A
9it [00:09,  1.04s/it][A
10it [00:10,  1.04s/it][A


11it [00:11,  1.05s/it][A
12it [00:12,  1.05s/it][A
13it [00:13,  1.05s/it][A
14it [00:14,  1.04s/it][A
15it [00:15,  1.04s/it][A
16it [00:16,  1.04s/it][A
17it [00:17,  1.04s/it][A
18it [00:18,  1.05s/it][A
19it [00:19,  1.05s/it][A
20it [00:20,  1.05s/it][A


21it [00:21,  1.05s/it][A
22it [00:22,  1.05s/it][A
23it [00:23,  1.04s/it][A
24it [00:24,  1.05s/it][A
25it [00:25,  1.05s/it][A
26it [00:26,  1.05s/it][A
27it [00:27,  1.04s/it][A
28it [00:29,  1.05s/it][A
29it [00:30,  1.04s/it][A
30it [00:31,  1.05s/it][A


31it [00:32,  1.06s/it][A
32it [00:33,  1.05s/it][A
33it [00:34,  1.04s/it][A
34it [00:35,  1.04s/it][A
35it [00:36,  1.04s/it][A
36it [00:37,  1.04s/it][A
37it [00:38,  1.04s/it][A
38it [00:39,  1.04s/it][A
39it [00:40,  1.04s/it][A
40it [00:41,  1.04s/it][A


41it [00:42,  1.04s/it][A
42it [00:43,  1.05s/it][A
43it [00:44,  1.04s/it][A
44it [00:45,  1.05s/it][A
45it [00:46,  1.04s/it][A
46it [00:47,  1.04s/it][A
47it [00:48,  1.04s/it][A
48it [00:49,  1.04s/it][A
49it [00:50,  1.05s/it][A
50it [00:51,  1.04s/it][A


51it [00:53,  1.05s/it][A
52it [00:54,  1.05s/it][A
53it [00:55,  1.05s/it][A
54it [00:56,  1.04s/it][A
55it [00:57,  1.04s/it][A
56it [00:58,  1.05s/it][A
57it [00:59,  1.05s/it][A
58it [01:00,  1.05s/it][A
59it [01:01,  1.04s/it][A
60it [01:02,  1.04s/it][A


61it [01:03,  1.05s/it][A
62it [01:04,  1.04s/it][A
63it [01:05,  1.04s/it][A
64it [01:06,  1.04s/it][A
65it [01:07,  1.04s/it][A
66it [01:08,  1.04s/it][A
67it [01:09,  1.04s/it][A
68it [01:10,  1.04s/it][A
69it [01:11,  1.04s/it][A
70it [01:12,  1.04s/it][A


71it [01:13,  1.05s/it][A
72it [01:14,  1.05s/it][A
73it [01:15,  1.04s/it][A
74it [01:17,  1.04s/it][A
75it [01:18,  1.04s/it][A
76it [01:19,  1.04s/it][A
77it [01:20,  1.05s/it][A
78it [01:21,  1.04s/it][A
79it [01:22,  1.04s/it][A
80it [01:23,  1.04s/it][A


81it [01:24,  1.05s/it][A
82it [01:25,  1.04s/it][A
83it [01:26,  1.04s/it][A
84it [01:27,  1.04s/it][A
85it [01:28,  1.04s/it][A
86it [01:29,  1.04s/it][A
87it [01:30,  1.04s/it][A
88it [01:31,  1.04s/it][A
 75%|████████████████████▎      | 3/4 [04:33<01:31, 91.23s/it]
0it [00:00, ?it/s][A


1it [00:00,  1.15it/s][A
2it [00:01,  1.03it/s][A
3it [00:02,  1.00it/s][A
4it [00:03,  1.01s/it][A
5it [00:05,  1.01s/it][A
6it [00:06,  1.02s/it][A
7it [00:07,  1.03s/it][A
8it [00:08,  1.03s/it][A
9it [00:09,  1.03s/it][A
10it [00:10,  1.03s/it][A


11it [00:11,  1.04s/it][A
12it [00:12,  1.04s/it][A
13it [00:13,  1.04s/it][A
14it [00:14,  1.04s/it][A
15it [00:15,  1.03s/it][A
16it [00:16,  1.04s/it][A
17it [00:17,  1.04s/it][A
18it [00:18,  1.05s/it][A
19it [00:19,  1.04s/it][A
20it [00:20,  1.04s/it][A


21it [00:21,  1.05s/it][A
22it [00:22,  1.04s/it][A
23it [00:23,  1.04s/it][A
24it [00:24,  1.04s/it][A
25it [00:25,  1.04s/it][A
26it [00:26,  1.05s/it][A
27it [00:27,  1.05s/it][A
28it [00:28,  1.04s/it][A
29it [00:30,  1.04s/it][A
30it [00:31,  1.04s/it][A


31it [00:32,  1.05s/it][A
32it [00:33,  1.04s/it][A
33it [00:34,  1.04s/it][A
34it [00:35,  1.04s/it][A
35it [00:36,  1.04s/it][A
36it [00:37,  1.04s/it][A
37it [00:38,  1.04s/it][A
38it [00:39,  1.04s/it][A
39it [00:40,  1.04s/it][A
40it [00:41,  1.03s/it][A


41it [00:42,  1.05s/it][A
42it [00:43,  1.04s/it][A
43it [00:44,  1.05s/it][A
44it [00:45,  1.04s/it][A
45it [00:46,  1.05s/it][A
46it [00:47,  1.04s/it][A
47it [00:48,  1.04s/it][A
48it [00:49,  1.04s/it][A
49it [00:50,  1.03s/it][A
50it [00:51,  1.04s/it][A


51it [00:52,  1.05s/it][A
52it [00:53,  1.05s/it][A
53it [00:55,  1.05s/it][A
54it [00:56,  1.04s/it][A
55it [00:57,  1.06s/it][A
56it [00:58,  1.06s/it][A
57it [00:59,  1.05s/it][A
58it [01:00,  1.05s/it][A
59it [01:01,  1.05s/it][A
60it [01:02,  1.04s/it][A


61it [01:03,  1.06s/it][A
62it [01:04,  1.06s/it][A
63it [01:05,  1.06s/it][A
64it [01:06,  1.06s/it][A
65it [01:07,  1.05s/it][A
66it [01:08,  1.05s/it][A
67it [01:09,  1.05s/it][A
68it [01:10,  1.05s/it][A
69it [01:11,  1.05s/it][A
70it [01:12,  1.05s/it][A


71it [01:13,  1.06s/it][A
72it [01:15,  1.05s/it][A
73it [01:16,  1.05s/it][A
74it [01:17,  1.05s/it][A
75it [01:18,  1.05s/it][A
76it [01:19,  1.05s/it][A
77it [01:20,  1.05s/it][A
78it [01:21,  1.04s/it][A
79it [01:22,  1.04s/it][A
80it [01:23,  1.04s/it][A


81it [01:24,  1.05s/it][A
82it [01:25,  1.04s/it][A
83it [01:26,  1.04s/it][A
84it [01:27,  1.04s/it][A
85it [01:28,  1.04s/it][A
86it [01:29,  1.04s/it][A
87it [01:30,  1.04s/it][A
88it [01:31,  1.04s/it][A
100%|███████████████████████████| 4/4 [06:05<00:00, 91.29s/it]


  0%|                                   | 0/5 [00:00<?, ?it/s]
0it [00:00, ?it/s][A


1it [00:15, 15.55s/it][A
2it [00:32, 16.45s/it][A
3it [00:48, 15.96s/it][A
4it [01:04, 16.20s/it][A
5it [01:19, 15.90s/it][A
6it [01:39, 17.12s/it][A
7it [01:54, 16.51s/it][A
8it [02:10, 16.29s/it][A
9it [02:24, 15.51s/it][A
10it [02:40, 15.74s/it][A


11it [02:56, 15.80s/it][A
12it [03:11, 15.71s/it][A
13it [03:29, 16.13s/it][A
14it [03:46, 16.38s/it][A
15it [04:02, 16.53s/it][A
16it [04:19, 16.50s/it][A
17it [04:35, 16.53s/it][A
18it [04:51, 16.33s/it][A
19it [05:08, 16.50s/it][A
20it [05:23, 16.02s/it][A


21it [05:40, 16.27s/it][A
22it [05:54, 16.12s/it][A
 20%|█████▏                    | 1/5 [05:54<23:39, 354.77s/it]
0it [00:00, ?it/s][A


1it [00:15, 15.26s/it][A
2it [00:32, 16.36s/it][A
3it [00:47, 15.99s/it][A
4it [01:04, 16.28s/it][A
5it [01:20, 15.96s/it][A
6it [01:39, 17.14s/it][A
7it [01:54, 16.49s/it][A
8it [02:10, 16.27s/it][A
9it [02:24, 15.52s/it][A
10it [02:40, 15.77s/it][A


11it [02:56, 15.85s/it][A
12it [03:12, 15.74s/it][A
13it [03:29, 16.11s/it][A
14it [03:46, 16.34s/it][A
15it [04:02, 16.50s/it][A
16it [04:19, 16.44s/it][A
17it [04:35, 16.46s/it][A
18it [04:51, 16.33s/it][A
19it [05:08, 16.48s/it][A
20it [05:23, 16.00s/it][A


21it [05:40, 16.23s/it][A
22it [05:54, 16.11s/it][A
 40%|██████████▍               | 2/5 [11:49<17:43, 354.65s/it]
0it [00:00, ?it/s][A


1it [00:15, 15.36s/it][A
2it [00:32, 16.54s/it][A
3it [00:48, 16.01s/it][A
4it [01:04, 16.35s/it][A
5it [01:20, 16.05s/it][A
6it [01:40, 17.30s/it][A
7it [01:55, 16.63s/it][A
8it [02:11, 16.40s/it][A
9it [02:25, 15.60s/it][A
10it [02:41, 15.83s/it][A


11it [02:57, 15.82s/it][A
12it [03:12, 15.72s/it][A
13it [03:29, 16.11s/it][A
14it [03:46, 16.35s/it][A
15it [04:03, 16.53s/it][A
16it [04:20, 16.51s/it][A
17it [04:36, 16.56s/it][A
18it [04:53, 16.45s/it][A
19it [05:09, 16.50s/it][A
20it [05:24, 16.02s/it][A


21it [05:41, 16.22s/it][A
22it [05:55, 16.16s/it][A
 60%|███████████████▌          | 3/5 [17:44<11:50, 355.06s/it]
0it [00:00, ?it/s][A


1it [00:15, 15.40s/it][A
2it [00:32, 16.53s/it][A
3it [00:48, 16.04s/it][A
4it [01:04, 16.31s/it][A
5it [01:20, 15.94s/it][A
6it [01:39, 17.05s/it][A
7it [01:54, 16.36s/it][A
8it [02:09, 16.13s/it][A
9it [02:23, 15.39s/it][A
10it [02:40, 15.66s/it][A


11it [02:55, 15.69s/it][A
12it [03:10, 15.50s/it][A
13it [03:27, 15.96s/it][A
14it [03:44, 16.19s/it][A
15it [04:00, 16.24s/it][A
16it [04:16, 16.18s/it][A
17it [04:33, 16.22s/it][A
18it [04:49, 16.13s/it][A
19it [05:05, 16.28s/it][A
20it [05:20, 15.79s/it][A


21it [05:37, 16.04s/it][A
22it [05:51, 15.97s/it][A
 80%|████████████████████▊     | 4/5 [23:36<05:53, 353.59s/it]
0it [00:00, ?it/s][A


1it [00:15, 15.03s/it][A
2it [00:32, 16.35s/it][A
3it [00:47, 15.87s/it][A
4it [01:04, 16.18s/it][A
5it [01:19, 15.75s/it][A
6it [01:38, 16.96s/it][A
7it [01:53, 16.29s/it][A
8it [02:08, 16.00s/it][A
9it [02:22, 15.31s/it][A
10it [02:38, 15.60s/it][A


11it [02:54, 15.62s/it][A
12it [03:09, 15.51s/it][A
13it [03:26, 15.89s/it][A
14it [03:43, 16.13s/it][A
15it [03:59, 16.27s/it][A
16it [04:15, 16.17s/it][A
17it [04:32, 16.26s/it][A
18it [04:48, 16.15s/it][A
19it [05:04, 16.30s/it][A
20it [05:19, 15.77s/it][A


21it [05:35, 16.04s/it][A
22it [05:50, 15.91s/it][A
100%|██████████████████████████| 5/5 [29:26<00:00, 353.27s/it]


[{'generated_text': 'Wiederaufnahme der Sitzungsperiode'}]


In [32]:
output_dir="./FAMNet/HoSoTAs/en-de-new"
path = os.path.join(output_dir, "model_files")
trained_tokenizer = T5Tokenizer.from_pretrained(path)
trained_model = T5ForConditionalGeneration.from_pretrained(path)

In [33]:
text = 'Resumption of the session'
text1 = 'I declare resumed the session of the European Parliament adjourned on Friday 17 December 1999, and I would like once again to wish you a happy new year in the hope that you enjoyed a pleasant festive period.'
inference(trained_model=trained_model, trained_tokenizer=trained_tokenizer, text=text1)

RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [21]:
# Ich erkläre die am Freitag, dem 17. Dezember unterbrochene Sitzungsperiode des Europäischen 
# Parlaments für wiederaufgenommen, wünsche Ihnen nochmals alles Gute zum Jahreswechsel und hoffe, 
# daß Sie schöne Ferien hatten.

# BLEU score
# from nltk.translate.bleu_score import sentence_bleu
def compute_bleu(y_pred, y_true):
    metric = load_metric('bleu')
    metric.add_batch(predictions=y_pred, references=y_true)
    report = metric.compute()
    bleu = report['bleu'] * 100
    return bleu

In [22]:
def evaluation(loader, model, tokenizer, device):
    y_true = []
    y_pred = []
    for i, batch in enumerate(loader):

        # Prepare and tokenize the source sentences
        src_sentences = [prefix + line for line in batch[args().src_language]]
        encoded_input = tokenizer(src_sentences, max_length=128,
                                  padding=True, truncation=True,
                                  return_tensors='pt', add_special_tokens=True).input_ids.to(device)

        # Translate and decode the inputs
        outputs = model.generate(encoded_input, max_length=175)
        batch_pred = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # Concatenate the translated and reference sentences
        for sentence in batch[args().tgt_language]:
            sentence = tokenizer.tokenize(sentence)
            # print(sentence)
            y_true.append([sentence])
        for sentence in batch_pred:
            sentence = tokenizer.tokenize(sentence)
            # print(sentence)
            y_pred.append(sentence)

    bleu = compute_bleu(y_pred, y_true)
    print('Bleu Score: {:.2f}'.format(bleu))

In [23]:
!nvidia-smi

Sat Mar 25 08:06:57 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.106.00   Driver Version: 460.106.00   CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:06.0 Off |                    0 |
| N/A   55C    P0    45W / 250W |  10282MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [24]:
torch.cuda.empty_cache()
!nvidia-smi

Sat Mar 25 08:06:58 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.106.00   Driver Version: 460.106.00   CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:06.0 Off |                    0 |
| N/A   55C    P0    45W / 250W |   4156MiB / 16280MiB |     46%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [25]:
from numba import cuda 
device = cuda.get_current_device()
device.reset()

In [31]:
CUDA_LAUNCH_BLOCKING=1