In [1]:
import os
import torch
import torch.nn as nn
from tqdm.auto import tqdm
from datetime import datetime
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [2]:
# Explicitly enable tokenizer parallelism through the environment flag.
os.environ.update({"TOKENIZERS_PARALLELISM": "true"})

In [3]:
%cd ..
from src import envs
from src.sl.losses import FocalLoss
from src.sl.dataset import GECDataset
from src.sl.utils import process_data, collate_func
from src.utils import load_text, write_json, freeze_params
from src.models.seq2labels import PretrainedEncoder, Seq2Labels
%cd notebooks

/home/rajk/Machine_Learning/DRL-GEC
/home/rajk/Machine_Learning/DRL-GEC/notebooks


In [4]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Local Functions

In [5]:
@torch.cuda.amp.autocast()
def evaluate(model, batch, criterion):
    """Run the model on one batch and return one summed loss value per sequence.

    Args:
        model: Callable invoked as ``model(tokens=...)``; its output is unpacked
            as a 3-D ``(batch, sequence, labels)`` tensor of logits.
        batch: Mapping with ``"tokens"``, ``"labels"`` and ``"masks"`` entries;
            the latter two are NumPy arrays converted with ``torch.from_numpy``.
        criterion: Loss callable applied to the flattened logits and labels. It
            has to return one value per position (e.g. ``reduction="none"``)
            because the result is reshaped back to ``(batch, sequence)``.

    Returns:
        A 1-D tensor with one entry per sequence: the per-position losses
        summed over the sequence axis.
    """
    # NOTE(review): the mask is moved to the device but never applied to the
    # loss below -- confirm that padded positions are excluded by `criterion`.
    masks = torch.from_numpy(batch["masks"]).to(device)
    targets = torch.from_numpy(batch["labels"]).to(device)
    logits = model(tokens=batch["tokens"])
    n_seqs, n_positions, n_labels = logits.shape
    # Flatten to (batch * sequence, labels) so the criterion scores every position.
    flat_logits = logits.view(-1, n_labels)
    per_position = criterion(flat_logits, targets.view(-1))
    return per_position.view(n_seqs, n_positions).sum(dim=-1)

# Define parameters

In [6]:
# Learning-rate schedule: training starts at `cold_lr`; the training loop
# switches to `warm_lr` once `cold_epochs` epochs have passed.
cold_lr, warm_lr = 1e-3, 1e-5
lr = cold_lr
cold_epochs = 2
num_epochs = 20
patience = 3                  # epochs without dev improvement before early stopping

# Regularisation
dropout = 0.1
weight_decay = 0

# Batching: the train loader uses batch_size / accumulation_size samples per step.
batch_size = 64
accumulation_size = 4

# Data selection
data_limit = 500_000          # cap on raw training lines (applied only when > 0)
keep_corrects = False
train_datasets = ["synthetic"]
val_datasets = ["synthetic"]

# Encoder unfreezing and optional starting checkpoint
num_unfreeze_layers = 0
model_path = None
train_type = "finetune" if model_path is not None else "pretrain"

# Logging: one TensorBoard directory per run, named after type, datasets and start time.
current_datetime = datetime.now().strftime("%d:%m:%Y_%H:%M")
run_name = "_".join((train_type, "-".join(train_datasets), current_datetime))
log_dir = os.path.join("sl_logs", run_name)
writer = SummaryWriter(log_dir=log_dir)
meta_data = dict(
    description="""
    Pretrain on 500k Synthetic data.
    Use Focal Loss.
    Use Unknown label
    """
)

# Load label vocabulary

In [7]:
# Read the label vocabulary and map every label to its position in it.
label_path = "../data/vocabs/labels.txt"
label_vocab = load_text(label_path)
label2index = dict(zip(label_vocab, range(len(label_vocab))))

# Load raw data

In [8]:
# Concatenate the raw lines of every training corpus, then cap the total at `data_limit`.
train_data = []
for corpus in tqdm(train_datasets, desc="Loading datasets", total=len(train_datasets)):
    train_data += load_text(f"../data/processed/{corpus}/data.gector")

n_loaded = len(train_data)
if data_limit > 0 and n_loaded > data_limit:
    print(f"Truncating amount of data from {n_loaded} to {data_limit}")
    train_data = train_data[:data_limit]
print(f"Total number of sentences: {len(train_data)}")

Loading datasets:   0%|          | 0/1 [00:00<?, ?it/s]

Truncating amount of data from 3290450 to 500000
Total number of sentences: 500000


In [9]:
# Concatenate the raw lines of every validation corpus (no size cap is applied here).
dev_data = []
for corpus in tqdm(val_datasets, desc="Loading datasets", total=len(val_datasets)):
    dev_file = f"../data/processed/{corpus}/dev.gector"
    dev_data += load_text(dev_file)
print(f"Total number of sentences: {len(dev_data)}")

Loading datasets:   0%|          | 0/1 [00:00<?, ?it/s]

Total number of sentences: 40000


# Extract tokens and labels from the raw data

In [10]:
# Training data is processed with the configured `keep_corrects`; dev data with keep_corrects=True.
train_tokens, train_labels = process_data(train_data, label_vocab, keep_corrects=keep_corrects)
dev_tokens, dev_labels = process_data(dev_data, label_vocab, keep_corrects=True)

train_dataset = GECDataset(train_tokens, train_labels, label2index)
dev_dataset = GECDataset(dev_tokens, dev_labels, label2index)

# The train loader yields micro-batches of batch_size / accumulation_size samples;
# the dev loader uses the full `batch_size` and keeps the data order fixed.
loader_kwargs = dict(num_workers=4, collate_fn=collate_func)
train_loader = DataLoader(
    train_dataset,
    batch_size=int(batch_size / accumulation_size),
    shuffle=True,
    **loader_kwargs,
)
dev_loader = DataLoader(
    dev_dataset,
    batch_size=batch_size,
    shuffle=False,
    **loader_kwargs,
)

Processing data:   0%|          | 0/500000 [00:00<?, ?it/s]

Amount of data after filtering: 470654


Processing data:   0%|          | 0/40000 [00:00<?, ?it/s]

Amount of data after filtering: 40000


In [11]:
# Build the encoder + tagging head, optionally restore weights, and create the
# loss, optimiser and AMP gradient scaler used by the training cell.
model_name = "roberta-base"
tokenizer_config = {"use_fast": True}
transformer_config = {"output_attentions": False}

encoder = PretrainedEncoder(model_name, tokenizer_config, transformer_config).to(device)
model = Seq2Labels(encoder_model=encoder, num_labels=len(label_vocab), dropout=dropout).to(device)
if model_path:
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU can still be opened on a CPU-only machine
    # (or on a machine with a different GPU index).
    model.load_state_dict(torch.load(model_path, map_location=device))
# reduction="none" keeps one loss value per position; `evaluate` reshapes and sums them per sequence.
criterion = FocalLoss(alpha=0.25, gamma=2.0, reduction="none")
# criterion = nn.CrossEntropyLoss(reduction="none")
optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
grad_scaler = torch.cuda.amp.GradScaler()
# Store the free-text run description next to the TensorBoard logs.
write_json(os.path.join(log_dir, "meta.json"), meta_data)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.dense.weight', 'roberta.pooler.dense.bias', 'roberta.pooler.dense.weight', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
def _validation_loss():
    """Return the mean per-sentence loss over `dev_loader`.

    The model is switched to eval mode for the pass (so dropout is disabled and
    the score is deterministic for fixed weights) and restored to whatever mode
    it was in before, even if evaluation raises.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            per_sentence = [evaluate(model, batch, criterion) for batch in dev_loader]
    finally:
        model.train(was_training)
    return torch.cat(per_sentence).mean()


model.train()
optim.zero_grad()
N = len(train_loader)
freeze_params(model.encoder, requires_grad=False)    # Freeze encoder model
# Log hyperparameters
writer.add_scalar("hyperparameters/dropout", dropout, 0)
writer.add_scalar("hyperparameters/patience", patience, 0)
writer.add_scalar("hyperparameters/batch_size", batch_size, 0)
writer.add_scalar("hyperparameters/cold_epochs", cold_epochs, 0)
writer.add_scalar("hyperparameters/weight_decay", weight_decay, 0)
writer.add_scalar("hyperparameters/keep_corrects", int(keep_corrects), 0)
writer.add_scalar("hyperparameters/accumulation_size", accumulation_size, 0)
writer.add_scalar("hyperparameters/num_unfreeze_layers", num_unfreeze_layers, 0)
writer.add_scalar("hyperparameters/uses_CE_loss", int(isinstance(criterion, nn.CrossEntropyLoss)), 0)

# Baseline dev loss before any update; it is also the first score to beat.
dev_loss = _validation_loss()
writer.add_scalar("sl/validation_loss", dev_loss, 0)

epochs_since_improvement = 0
best_dev_score = dev_loss
for epoch in tqdm(range(num_epochs), desc="Training", total=num_epochs):
    if epoch == cold_epochs:                                                                           # End of the cold epochs
        lr = warm_lr
        freeze_params(model.encoder, requires_grad=True, num_layers=num_unfreeze_layers, optim=optim, lr=lr)               # Unfreeze encoder model

    step_offset = epoch*N
    for i, batch in tqdm(enumerate(train_loader), desc=f"Epoch {epoch+1}", total=N):
        loss = evaluate(model, batch, criterion)
        loss = loss.mean()
        # Divide by the window size so the accumulated gradient is an average over micro-batches.
        grad_scaler.scale(loss/accumulation_size).backward()
        # Step at the end of every accumulation window, and also on the last
        # batch of the epoch so a trailing partial window is applied instead of
        # leaking its gradients into the next epoch.
        if ((i+1) % accumulation_size) == 0 or (i+1) == N:
            grad_scaler.step(optim)
            grad_scaler.update()
            optim.zero_grad()
            torch.cuda.empty_cache()
        writer.add_scalar("sl/lr", lr, step_offset + i)
        writer.add_scalar("sl/train_loss", loss, step_offset + i)
    dev_loss = _validation_loss()
    writer.add_scalar("sl/validation_loss", dev_loss, step_offset + i)
    if dev_loss <= best_dev_score:
        best_dev_score = dev_loss
        epochs_since_improvement = 0
        torch.save(model.state_dict(), os.path.join(log_dir, "model-best.pt"))     # Save best model
    else:
        epochs_since_improvement += 1
        if epochs_since_improvement >= patience:
            print("Early stopping!")
            break

# Push buffered TensorBoard events to disk before any later cell shuts the machine down.
writer.flush()

Number of frozen parameters: 197/197


Training:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/29416 [00:00<?, ?it/s]

Epoch 2:   0%|          | 0/29416 [00:00<?, ?it/s]

Number of frozen parameters: 0/197


Epoch 3:   0%|          | 0/29416 [00:00<?, ?it/s]

Epoch 4:   0%|          | 0/29416 [00:00<?, ?it/s]

In [None]:
# Persist the final weights, then reload the best checkpoint into the in-memory model.
last_checkpoint = os.path.join(log_dir, "model-last.pt")
best_checkpoint = os.path.join(log_dir, "model-best.pt")
torch.save(model.state_dict(), last_checkpoint)

model.load_state_dict(torch.load(best_checkpoint))

In [None]:
import gym

In [None]:
# Create the registered GEC environment used for the qualitative rollout below.
env_id = "wi_locness_gec-v0"
env_options = {
    "correct_examples_percent": [0.0],
    "repeat": 1,
    "min_num_refs": [1],
}
env = gym.make(env_id, **env_options)

In [None]:
# Roll out one episode greedily: at every step print each token with its five
# highest-scoring labels, act with the arg-max label, and show the rendered result.
model.eval()
state = env.reset()
print("# References")
for reference in env.reference_tokens_list:
    print(reference)
print()
done = False
while not done:
    with torch.no_grad():
        [logits] = model([state])
        actions = logits.argmax(-1).cpu().numpy()
        top_scores, top_ids = (t.cpu().numpy() for t in logits.topk(5))
        # One printed row per token: the token followed by its top-5 (label, score) pairs.
        for token, (names, scores) in zip(state, zip(env.labels[top_ids], top_scores)):
            cells = [f"{name:9} [{score:5.2f}]" for name, score in zip(names, scores)]
            print(f"{token:15}", " --- ".join(cells))
        print()
    state, reward, done, info = env.step(actions)
    for rendered_line in env.render():
        print(rendered_line)

# Close Google Compute Instance

In [None]:
# Shell escape: ask the gcloud CLI to stop the `drl-gec` instance in zone us-west1-b.
# NOTE(review): presumably this is the VM hosting the notebook, so running the
# cell ends the session -- confirm before using "Run All".
!gcloud compute instances stop drl-gec --zone us-west1-b