In [1]:
# Init/Load model
from transformers import BartForConditionalGeneration, BartTokenizer
import numpy as np
import torch
import os

# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define a directory to save the models
SAVE_DIR = '../saved_models'
# exist_ok makes this a no-op when the directory is already there and avoids the
# check-then-create race of a separate os.path.exists() test
os.makedirs(SAVE_DIR, exist_ok=True)

# Epoch number of the checkpoint to resume from; 0 means start from the
# pretrained 'facebook/bart-base' weights instead of a saved checkpoint
start_epoch = 36

class SimpleBART(torch.nn.Module):
    """Thin torch module wrapping a BART conditional-generation model and its tokenizer.

    The weights come from the checkpoint directory ``SAVE_DIR/epoch_<start_epoch>``
    when the module-level ``start_epoch`` is positive, otherwise from the
    pretrained ``facebook/bart-base`` model. The tokenizer is always the
    ``facebook/bart-base`` one.
    """

    def __init__(self):
        super().__init__()
        resume_from_checkpoint = start_epoch > 0
        if resume_from_checkpoint:
            checkpoint_dir = os.path.join(SAVE_DIR, f'epoch_{start_epoch}')
            self.bart = BartForConditionalGeneration.from_pretrained(checkpoint_dir)
            print(f'Loaded epoch_{start_epoch}')
        else:
            self.bart = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
            print('Loaded facebook/bart-base')
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

    def forward(self, input_ids, attention_mask):
        """Run the wrapped BART model on the encoder inputs and return its output unchanged."""
        result = self.bart(input_ids=input_ids, attention_mask=attention_mask)
        return result


# Build the wrapper (loads a checkpoint or the base weights, see SimpleBART) and move it to `device`
model = SimpleBART().to(device)
# NOTE: a fresh Adam optimiser is created on every kernel start; save_model() only
# writes the BART weights, so the optimiser state is not restored when resuming
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
# Token-level cross entropy over the vocabulary logits
criterion = torch.nn.CrossEntropyLoss()

Loaded epoch_36


In [6]:
# Load Raw Datasets
from torch.utils.data import Dataset, DataLoader
import csv

# Number of samples to accumulate gradients over before an optimiser step;
# trainFor() divides this by BATCH_SIZE to get the number of batches per step
ACCUMULATION_STEPS = 14
BATCH_SIZE = 14 # best performing batch size so far (in execution performance)
# Size of the most recently loaded dataset; overwritten by every apply_concept() call
DATA_SIZE = 0

class CustomDataset(Dataset):
    """Dataset of ``(text, token_ids)`` pairs producing fixed-length model inputs and targets.

    Args:
        data: sequence of ``(text, tokens)`` tuples, where ``tokens`` is a list of int ids.
        tokenizer: callable tokenizer exposing ``pad_token_id``, ``cls_token_id`` and
            ``eos_token_id``; it is called on ``text`` to build the encoder inputs.
        max_length: length every input and target sequence is padded/truncated to.
    """

    def __init__(self, data, tokenizer, max_length=200):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.pad_token_id = tokenizer.pad_token_id
        self.start_token_id = tokenizer.cls_token_id
        self.end_token_id = tokenizer.eos_token_id

    def __len__(self):
        """Number of ``(text, tokens)`` pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, target_ids)`` for sample ``idx``.

        ``target_ids`` is ``[start] + tokens + [end]`` right-padded with the pad id
        to exactly ``max_length`` entries.
        """
        text, tokens = self.data[idx]

        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.max_length)

        # Keep room for the start and end markers: an over-long target would otherwise
        # produce a tensor longer than max_length and break batch collation
        body = list(tokens)[:max(self.max_length - 2, 0)]
        framed = [self.start_token_id] + body + [self.end_token_id]
        padding = [self.pad_token_id] * (self.max_length - len(framed))
        # Final slice only matters for degenerate max_length < 2
        target_ids = (framed + padding)[:self.max_length]

        return inputs["input_ids"].squeeze(0), inputs["attention_mask"].squeeze(0), torch.tensor(target_ids, dtype=torch.long)


def load_data_from_csv(file_path):
    """Read ``(text, token_ids)`` pairs from a two-column CSV file.

    Column 0 is the text; column 1 is a comma-separated list of integer token
    ids (so it must be quoted in the file). Blank lines are skipped.

    Args:
        file_path: path of the CSV file to read.

    Returns:
        A list of ``(text, [int, ...])`` tuples in file order.

    Raises:
        ValueError: if a token in column 1 is not an integer.
        IndexError: if a non-blank row has fewer than two columns.
    """
    # newline='' lets the csv module do its own newline handling, so line breaks
    # embedded in quoted fields are preserved as written
    with open(file_path, 'r', newline='') as file:
        reader = csv.reader(file)
        data = [
            (row[0], [int(tok) for tok in row[1].split(",")])
            for row in reader
            if row  # csv.reader yields [] for blank lines
        ]

    return data

def apply_concept(params, validation=False):
    """Build a DataLoader from a mix of concept CSV files and publish it as a global.

    ``../concept/egg.csv`` is always loaded first; each entry of ``params`` then
    appends a slice of ``../concept/<file_name>.csv``.

    Args:
        params: mapping of concept file name (without ``.csv``) to the fraction
            of that file's rows to load.
        validation: when True the slice is taken from the END of each file and
            the loader is stored in the global ``validationLoader``; otherwise
            the slice is taken from the START and stored in the global
            ``dataloader``.

    Side effects:
        Sets the global ``DATA_SIZE`` to the merged row count and prints a
        per-file and total summary.
    """
    global validationLoader
    global dataloader
    global DATA_SIZE

    merged_data = load_data_from_csv("../concept/egg.csv")

    print("Concepts loaded;")
    for file_name, percentage in params.items():
        data = load_data_from_csv(f"../concept/{file_name}.csv")
        cutoff = int(len(data) * percentage)

        if validation:
            # data[-cutoff:] would return the WHOLE list when cutoff == 0, so
            # compute the start index explicitly (clamped for cutoff > len(data))
            tail_start = max(len(data) - cutoff, 0)
            loaded_data = data[tail_start:]
        else:
            loaded_data = data[:cutoff]

        print(f"    - {file_name}: {len(loaded_data)}")
        merged_data.extend(loaded_data)

    DATA_SIZE = len(merged_data)
    dataset = CustomDataset(merged_data, model.tokenizer)
    if validation:
        validationLoader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    else:
        dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    print(f'  Total: {DATA_SIZE}')

In [3]:
# Setup Validator
def validate():
    """Evaluate the global ``model`` on the global ``validationLoader`` and print two accuracies.

    Only target positions before the first EOS token, plus positions holding an
    EOS token, are scored; everything else (padding) is ignored.

    Returns:
        ``(avg_accuracy, validation_accuracy)`` where ``avg_accuracy`` is the
        mean over sequences of the fraction of scored tokens predicted correctly
        and ``validation_accuracy`` is the fraction of sequences whose scored
        tokens are all correct.
    """
    eos_id = model.tokenizer.eos_token_id

    model.eval()

    # Running totals over the whole validation set
    fully_correct = 0
    sequences_seen = 0
    accuracy_sum = 0

    with torch.no_grad():
        for input_ids, attention_mask, targets in validationLoader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            targets = targets.to(device)
            logits = model(input_ids, attention_mask).logits

            # True at every position at or after the first EOS of each target row
            is_eos = targets == eos_id
            at_or_after_eos = is_eos.cumsum(dim=1).type(torch.bool)
            # Scored positions: everything before the first EOS, and the EOS itself
            mask = ~at_or_after_eos | is_eos

            predicted = logits.max(2)[1]
            hits = predicted == targets

            # A sequence counts as correct when every scored position matches
            fully_correct += (hits | ~mask).all(dim=1).float().sum().item()
            sequences_seen += targets.size(0)

            # Per-sequence token accuracy restricted to scored positions
            hits_per_sequence = (hits & mask).float().sum(dim=1)
            scored_per_sequence = mask.float().sum(dim=1)
            accuracy_sum += (hits_per_sequence / scored_per_sequence).sum().item()

    # Aggregate and report over the entire validation dataset
    validation_accuracy = fully_correct / sequences_seen
    print(f"  Total Seq Acc: {validation_accuracy*100:.3f}%")
    avg_accuracy = accuracy_sum / sequences_seen
    print(f"    Avg Seq Acc: {avg_accuracy*100:.3f}%")

    return avg_accuracy, validation_accuracy

In [4]:
# Setup trainer
def save_model():
    """Save the wrapped BART weights to ``SAVE_DIR/epoch_<start_epoch>``."""
    # start_epoch is only read here, so no `global` declaration is required
    checkpoint_dir = os.path.join(SAVE_DIR, f'epoch_{start_epoch}')
    model.bart.save_pretrained(checkpoint_dir)

def trainFor(num_epochs, target_loss=0, target_p50_loss=0, target_acc=1.0, target_seq_acc=1.0):
    """Train the global ``model`` on the global ``dataloader`` for up to ``num_epochs`` epochs.

    Epoch numbering continues from the global ``start_epoch``, which is advanced
    at the start of every epoch. After each epoch the model is validated with
    ``validate()`` and training stops early as soon as any target is met. A
    checkpoint is written every 10th epoch and always once on exit.

    Args:
        num_epochs: maximum number of epochs to run.
        target_loss: stop when the mean batch loss of an epoch is <= this value.
        target_p50_loss: stop when the median batch loss of an epoch is <= this value.
        target_acc: stop when the fraction of validation sequences that are
            entirely correct (second value returned by ``validate()``) is >= this value.
        target_seq_acc: stop when the average per-sequence token accuracy
            (first value returned by ``validate()``) is >= this value.
    """
    global start_epoch

    EOS_TOKEN_ID = model.tokenizer.eos_token_id
    # Batches per optimiser step; clamped to 1 so ACCUMULATION_STEPS < BATCH_SIZE
    # cannot produce a modulo-by-zero below
    acc_batch = max(1, int(ACCUMULATION_STEPS / BATCH_SIZE))
    total_batches = len(dataloader)

    for epoch in range(start_epoch+1, start_epoch+num_epochs+1):
        start_epoch = epoch
        model.train()

        # Resetting the accumulated gradients
        optimizer.zero_grad()

        # Initialize counters for accuracy calculation
        total_correct_sequences = 0
        total_sequences = 0
        cumulative_loss = 0.0

        # Initialize list to store batch losses
        batch_losses = []

        for batch_idx, (input_ids, attention_mask, targets) in enumerate(dataloader):

            input_ids, attention_mask, targets = input_ids.to(device), attention_mask.to(device), targets.to(device)
            outputs = model(input_ids, attention_mask)
            logits = outputs.logits

            # Identify where the EOS token is in the target sequence
            eos_positions = (targets == EOS_TOKEN_ID).cumsum(dim=1).type(torch.bool)
            mask = ~eos_positions | (targets == EOS_TOKEN_ID)

            # Apply mask to filter out tokens after the EOS token for loss computation
            active_loss = mask.view(-1).bool()
            active_logits = logits.view(-1, logits.size(-1))[active_loss]
            active_labels = targets.view(-1)[active_loss]
            loss = criterion(active_logits, active_labels)

            _, predicted = logits.max(2)
            correct_sequences = ((predicted == targets) | ~mask).all(dim=1).float().sum().item()
            total_sequences += targets.size(0)
            total_correct_sequences += correct_sequences

            # Accumulate the gradients, scaled so the accumulated gradient is the
            # mean (not the sum) over the batches in the accumulation window
            (loss / acc_batch).backward()
            loss_val = loss.item()  # report the unscaled loss

            cumulative_loss += loss_val
            batch_losses.append(loss_val)

            isLast = batch_idx == total_batches - 1

            # Step once per full window of acc_batch batches (counting from 1, so
            # the first step is not taken after a single batch), and flush any
            # partial window at the end of the epoch
            if isLast or (batch_idx + 1) % acc_batch == 0:
                optimizer.step()
                optimizer.zero_grad()

            print(f"\rEpoch: {epoch}, Batch: {batch_idx} of {total_batches}, loss: {loss_val:.6f}      ", end='')


        # Compute and print the accuracy for the entire epoch
        epoch_accuracy = total_correct_sequences / total_sequences
        cumulative_loss = cumulative_loss / len(batch_losses)
        p25_loss = np.percentile(batch_losses, 25)
        p50_loss = np.percentile(batch_losses, 50)
        p75_loss = np.percentile(batch_losses, 75)
        print(f"\rEpoch: {epoch}, Accuracy: {epoch_accuracy*100:.2f}%                                             ")
        print(f"  Loss: 25%: {p25_loss:.6f} 50%: {p50_loss:.6f} 75%: {p75_loss:.6f}\n   Avg: {cumulative_loss:.6f}")

        if epoch % 10 == 0: # Save the model
            save_model()

        # validate() returns (average per-sequence accuracy, whole-sequence accuracy)
        seq_acc, total_acc = validate()

        if total_acc >= target_acc:
            break
        if cumulative_loss <= target_loss:
            break
        if p50_loss <= target_p50_loss:
            break

        if seq_acc >= target_seq_acc:
            break

    # Make sure last epoch is always saved
    save_model()

In [5]:
ACCUMULATION_STEPS = BATCH_SIZE # pure memorisation so accumulation won't help
apply_concept({"vocabulary": 1.0, "noise": 0.1}, validation=True)
apply_concept({"vocabulary": 1.0, "noise": 0.02})
trainFor(50, target_seq_acc=0.50)
# (31mins)

Concepts loaded;
    - vocabulary: 6808
    - noise: 15614
  Total: 22423
Concepts loaded;
    - vocabulary: 6808
    - noise: 3122
  Total: 9931
Epoch: 1, Accuracy: 0.01%
  Loss: 25%: 4.393126 50%: 4.887876 75%: 5.315798
   Avg: 4.884056
  Total Seq Acc: 0.000%
    Avg Seq Acc: 39.868%
Epoch: 2, Accuracy: 0.03%
  Loss: 25%: 4.117525 50%: 4.550774 75%: 4.877478
   Avg: 4.482987
  Total Seq Acc: 0.004%
    Avg Seq Acc: 39.875%
Epoch: 3, Accuracy: 0.03%
  Loss: 25%: 4.068293 50%: 4.452379 75%: 4.818588
   Avg: 4.398917
  Total Seq Acc: 0.062%
    Avg Seq Acc: 39.897%
Epoch: 4, Accuracy: 0.05%
  Loss: 25%: 4.008659 50%: 4.428133 75%: 4.759971
   Avg: 4.360003
  Total Seq Acc: 0.085%
    Avg Seq Acc: 39.924%
Epoch: 5, Accuracy: 0.16%
  Loss: 25%: 3.946199 50%: 4.349040 75%: 4.712277
   Avg: 4.299947
  Total Seq Acc: 0.214%
    Avg Seq Acc: 40.006%
Epoch: 6, Accuracy: 0.48%
  Loss: 25%: 3.817751 50%: 4.224310 75%: 4.577750
   Avg: 4.175182
  Total Seq Acc: 1.057%
    Avg Seq Acc: 40.588%
Ep

In [6]:
ACCUMULATION_STEPS = BATCH_SIZE # pure memorisation so accumulation won't help
apply_concept({"vocabulary": 1.0, "noise": 0.04})
trainFor(50, target_seq_acc=0.60)
# (4mins)

Concepts loaded;
    - vocabulary: 6808
    - noise: 6245
  Total: 13054
Epoch: 9, Accuracy: 10.68%
  Loss: 25%: 2.948524 50%: 3.301470 75%: 3.602107
   Avg: 3.285144
  Total Seq Acc: 16.104%
    Avg Seq Acc: 63.344%


In [7]:
ACCUMULATION_STEPS = 28 # *14 close to 32
apply_concept({"vocabulary": 1.0, "noise": 0.06})
trainFor(50, target_seq_acc=0.65)
# (5mins)

Concepts loaded;
    - vocabulary: 6808
    - noise: 9368
  Total: 16177
Epoch: 10, Accuracy: 14.79%
  Loss: 25%: 2.290249 50%: 2.601158 75%: 2.916466
   Avg: 2.600121
  Total Seq Acc: 25.349%
    Avg Seq Acc: 75.091%


Aiming for:
$$
\begin{align*}
  \frac{unique}{tokens} &= \frac{4174}{6808} = 61.31\%  & \text{sign pairs to text only used once} \\
  \frac{text}{tokens}   &= \frac{5342}{6808}  = 78.47\% & \text{sign pairs to text unique text} \\
  & & \text{unique meaning the text is only used for one tokenID}
\end{align*}
$$

i.e. there are six different signs which can be used for "present"  
which means we're actually aiming for $\frac{75\%}{78\%} \approx 96\%$ effective accuracy

But we don't want to over-fit either  
Hence why we slowly introduce new concepts while still memorising vocabulary

`target_seq_acc` includes the `EOS` token, which of course we want to be right
So our target should be $\frac{61.31\% + 100\%}{2} = 80.65\%$

In [8]:
ACCUMULATION_STEPS = 70 # *14 close to 64
apply_concept({"vocabulary": 1.0, "noise": 0.08})
trainFor(50, target_seq_acc=0.80655)
# (10mins)

Concepts loaded;
    - vocabulary: 6808
    - noise: 12491
  Total: 19300
Epoch: 11, Accuracy: 18.79%
  Loss: 25%: 1.740285 50%: 1.988600 75%: 2.250141
   Avg: 2.004923
  Total Seq Acc: 31.401%
    Avg Seq Acc: 79.550%
Epoch: 12, Accuracy: 23.83%
  Loss: 25%: 1.321281 50%: 1.537223 75%: 1.742366
   Avg: 1.547170
  Total Seq Acc: 35.058%
    Avg Seq Acc: 81.933%


In [9]:
apply_concept({"vocabulary": 1.0, "noise": 0.1}, validation=True)
validate()
apply_concept({"vocabulary": 1.0}, validation=True)
validate()
apply_concept({"noise": 0.1}, validation=True)
validate()

Concepts loaded;
    - vocabulary: 6808
    - noise: 15614
  Total: 22423
  Total Seq Acc: 35.058%
    Avg Seq Acc: 81.933%
Concepts loaded;
    - vocabulary: 6808
  Total: 6809
  Total Seq Acc: 62.109%
    Avg Seq Acc: 87.325%
Concepts loaded;
    - noise: 15614
  Total: 15615
  Total Seq Acc: 23.260%
    Avg Seq Acc: 79.579%


(0.795788674411047, 0.23259686199167467)

The pure memorisation is over, as the vocabulary has been sufficiently learnt  
While it's not perfect, it will improve further over the later training, as the vocab remains in the training set

In [11]:
ACCUMULATION_STEPS = 1022 # *14 close to 1024
apply_concept({"vocabulary": 1.0, "noise": 0.1}, validation=True)
apply_concept({"vocabulary": 1.0, "noise": 0.225})
trainFor(50, target_seq_acc=0.6)
# (10mins)

Concepts loaded;
    - vocabulary: 6808
    - noise: 15614
  Total: 22423
Concepts loaded;
    - vocabulary: 6808
    - noise: 35132
  Total: 41941
Epoch: 13, Accuracy: 19.62%
  Loss: 25%: 1.276297 50%: 1.521877 75%: 1.806241
   Avg: 1.567434
  Total Seq Acc: 37.671%
    Avg Seq Acc: 83.350%


In [12]:
# Break down of validation per concept
apply_concept({"vocabulary": 1.0}, validation=True)
validate()
apply_concept({"noise": 0.1}, validation=True)
validate()

Concepts loaded;
    - vocabulary: 6808
  Total: 6809
  Total Seq Acc: 64.576%
    Avg Seq Acc: 88.098%
Concepts loaded;
    - noise: 15614
  Total: 15615
  Total Seq Acc: 25.937%
    Avg Seq Acc: 81.276%


(0.8127563179130933, 0.25936599423631124)

In [13]:
ACCUMULATION_STEPS = 1022 # *14 close to 1024
apply_concept({"vocabulary": 1.0, "noise": 0.1}, validation=True)
apply_concept({"vocabulary": 1.0, "noise": 0.45})
trainFor(50, target_seq_acc=0.7)
# (16mins)

Concepts loaded;
    - vocabulary: 6808
    - noise: 15614
  Total: 22423
Concepts loaded;
    - vocabulary: 6808
    - noise: 70264
  Total: 77073
Epoch: 14, Accuracy: 21.77%
  Loss: 25%: 0.978782 50%: 1.179458 75%: 1.420422
   Avg: 1.226024
  Total Seq Acc: 42.332%
    Avg Seq Acc: 86.016%


In [14]:
# Validate the model saved correctly
apply_concept({"vocabulary": 1.0, "noise": 0.1}, validation=True)
validate()
apply_concept({"vocabulary": 1.0}, validation=True)
validate()
apply_concept({"noise": 0.1}, validation=True)
validate()

Concepts loaded;
    - vocabulary: 6808
    - noise: 15614
  Total: 22423
  Total Seq Acc: 42.332%
    Avg Seq Acc: 86.016%
Concepts loaded;
    - vocabulary: 6808
  Total: 6809
  Total Seq Acc: 67.190%
    Avg Seq Acc: 88.945%
Concepts loaded;
    - noise: 15614
  Total: 15615
  Total Seq Acc: 31.489%
    Avg Seq Acc: 84.735%


(0.8473481228669942, 0.3148895292987512)

In [15]:
ACCUMULATION_STEPS = 1022 # *14 close to 1024
apply_concept({"vocabulary": 1.0, "noise": 0.1}, validation=True)
apply_concept({"vocabulary": 1.0, "noise": 0.9})
trainFor(50, target_seq_acc=0.75)
# (29mins)

Concepts loaded;
    - vocabulary: 6808
    - noise: 15614
  Total: 22423
Concepts loaded;
    - vocabulary: 6808
    - noise: 140528
  Total: 147337
Epoch: 15, Accuracy: 25.32%
  Loss: 25%: 0.735360 50%: 0.893032 75%: 1.094783
   Avg: 0.942284
  Total Seq Acc: 46.069%
    Avg Seq Acc: 87.619%


In [16]:
apply_concept({"vocabulary": 1.0, "noise": 0.1}, validation=True)
validate()
apply_concept({"vocabulary": 1.0}, validation=True)
validate()
apply_concept({"noise": 0.1}, validation=True)
validate()

Concepts loaded;
    - vocabulary: 6808
    - noise: 15614
  Total: 22423
  Total Seq Acc: 46.069%
    Avg Seq Acc: 87.619%
Concepts loaded;
    - vocabulary: 6808
  Total: 6809
  Total Seq Acc: 69.863%
    Avg Seq Acc: 89.861%
Concepts loaded;
    - noise: 15614
  Total: 15615
  Total Seq Acc: 35.690%
    Avg Seq Acc: 86.637%


(0.8663697381680721, 0.3569004162664105)

In [17]:
ACCUMULATION_STEPS = 1022 # *14 close to 1024
apply_concept({"vocabulary": 1.0, "noise": 0.9})
trainFor(50, target_acc=0.8)
# (583mins)

Concepts loaded;
    - vocabulary: 6808
    - noise: 140528
  Total: 147337
Epoch: 16, Accuracy: 29.25%
  Loss: 25%: 0.571872 50%: 0.698485 75%: 0.865685
   Avg: 0.749544
  Total Seq Acc: 36.574%
    Avg Seq Acc: 87.146%
Epoch: 17, Accuracy: 31.80%
  Loss: 25%: 0.486688 50%: 0.593760 75%: 0.741733
   Avg: 0.647466
  Total Seq Acc: 38.834%
    Avg Seq Acc: 87.917%
Epoch: 18, Accuracy: 33.24%
  Loss: 25%: 0.436313 50%: 0.527889 75%: 0.659536
   Avg: 0.579617
  Total Seq Acc: 39.065%
    Avg Seq Acc: 88.180%
Epoch: 19, Accuracy: 34.66%
  Loss: 25%: 0.394753 50%: 0.478254 75%: 0.594729
   Avg: 0.527243
  Total Seq Acc: 39.577%
    Avg Seq Acc: 88.354%
Epoch: 20, Accuracy: 35.63%
  Loss: 25%: 0.367282 50%: 0.446331 75%: 0.552213
   Avg: 0.491565
  Total Seq Acc: 39.801%
    Avg Seq Acc: 88.395%
Epoch: 21, Accuracy: 36.52%
  Loss: 25%: 0.344380 50%: 0.417024 75%: 0.515350
   Avg: 0.460924
  Total Seq Acc: 39.923%
    Avg Seq Acc: 88.536%
Epoch: 22, Accuracy: 37.15%
  Loss: 25%: 0.327414 50%:

In [18]:
save_model()

I accidentally set `target_acc` (the fraction of sequences that are entirely correct) instead of `target_seq_acc` (the average per-sequence token accuracy), so this training stage had to be stopped manually

In [19]:
apply_concept({"vocabulary": 1.0, "noise": 0.1}, validation=True)
validate()
apply_concept({"vocabulary": 1.0}, validation=True)
validate()
apply_concept({"noise": 0.1}, validation=True)
validate()

Concepts loaded;
    - vocabulary: 6808
    - noise: 15614
  Total: 22423
  Total Seq Acc: 52.339%
    Avg Seq Acc: 90.122%
Concepts loaded;
    - vocabulary: 6808
  Total: 6809
  Total Seq Acc: 78.132%
    Avg Seq Acc: 92.685%
Concepts loaded;
    - noise: 15614
  Total: 15615
  Total Seq Acc: 41.089%
    Avg Seq Acc: 89.000%


(0.8899978608990418, 0.41088696765930194)

In [5]:
ACCUMULATION_STEPS = 1022 # *14 close to 1024
apply_concept({"vocabulary": 1.0, "noise": 0.1}, validation=True)
apply_concept({"vocabulary": 1.0, "noise": 0.9})
trainFor(50, target_p50_loss=0.2, target_seq_acc=0.90)
# (31mins)

Concepts loaded;
    - vocabulary: 6808
    - noise: 15614
  Total: 22423
Concepts loaded;
    - vocabulary: 6808
    - noise: 140528
  Total: 147337
Epoch: 36, Accuracy: 41.15%10525, loss: 0.603985      
  Loss: 25%: 0.235659 50%: 0.282068 75%: 0.340233
   Avg: 0.309650
  Total Seq Acc: 52.727%
    Avg Seq Acc: 90.170%


In [6]:
apply_concept({"vocabulary": 1.0, "noise": 0.1}, validation=True)
validate()
apply_concept({"vocabulary": 1.0}, validation=True)
validate()
apply_concept({"noise": 0.1}, validation=True)
validate()

Concepts loaded;
    - vocabulary: 6808
    - noise: 15614
  Total: 22423
  Total Seq Acc: 52.727%
    Avg Seq Acc: 90.170%
Concepts loaded;
    - vocabulary: 6808
  Total: 6809
  Total Seq Acc: 78.249%
    Avg Seq Acc: 92.715%
Concepts loaded;
    - noise: 15614
  Total: 15615
  Total Seq Acc: 41.595%
    Avg Seq Acc: 89.056%


(0.8905585711635232, 0.4159462055715658)

In [7]:
ACCUMULATION_STEPS = 1022 # *14 close to 1024
apply_concept({"vocabulary": 1.0, "grammar": 0.1}, validation=True)
apply_concept({"vocabulary": 1.0, "grammar": 0.9})
trainFor(3, target_p50_loss=0.2, target_seq_acc=0.90)
# (mins)

Concepts loaded;
    - vocabulary: 6808
    - grammar: 10
  Total: 6819
Concepts loaded;
    - vocabulary: 6808
    - grammar: 90
  Total: 6899
Epoch: 37, Accuracy: 72.87%                                             
  Loss: 25%: 0.206330 50%: 0.297163 75%: 0.455768
   Avg: 1.136580
  Total Seq Acc: 78.267%
    Avg Seq Acc: 92.671%


In [8]:
ACCUMULATION_STEPS = 1022 # *14 close to 1024
apply_concept({"grammar": 0.1}, validation=True)
apply_concept({"vocabulary": 1.0, "grammar": 0.9})
trainFor(3, target_p50_loss=0.2, target_seq_acc=0.90)
# (mins)

Concepts loaded;
    - grammar: 10
  Total: 11
Concepts loaded;
    - vocabulary: 6808
    - grammar: 90
  Total: 6899
Epoch: 38, Accuracy: 73.69%                                             
  Loss: 25%: 0.211418 50%: 0.300939 75%: 0.432739
   Avg: 0.994158
  Total Seq Acc: 0.000%
    Avg Seq Acc: 19.835%
Epoch: 39, Accuracy: 74.16%                                             
  Loss: 25%: 0.214725 50%: 0.280471 75%: 0.424626
   Avg: 0.985662
  Total Seq Acc: 0.000%
    Avg Seq Acc: 19.056%
Epoch: 40, Accuracy: 73.60%                                             
  Loss: 25%: 0.204188 50%: 0.284730 75%: 0.447475
   Avg: 0.963902
  Total Seq Acc: 0.000%
    Avg Seq Acc: 20.938%
