In [None]:
# QCB455/COS551 Final Project Pretraining Loss Visualization
# Author: Supraj Gunda

# The image that this file produces is shown in the directory as "pretrain_loss.png".
# The image is not finalized, but it shows work towards visualizing such a metric.
# It is not finalized because the regex extraction from the slurm .out files produced by the cluster is not consistent.
# Furthermore, we thought that the analysis/visualizations we produced were already sufficient.

from glob import glob
import re
import numpy as np

In [None]:
# builds on Colin's extract metrics method to parse the .out files produced by the cluster pretraining jobs

def extract_metrics(file):
    """Scan one training log and report its configuration and best perplexity.

    Every line is tested against every pattern independently, so a single
    line that carries several keys (for example ``d_model: 128, n_layer: 4``,
    or a checkpoint line containing both ``Epoch 3`` and ``best 1.5``)
    contributes all of them.

    Args:
        file: Path of the text log to read.

    Returns:
        A tuple ``(d_model, n_layer, bed_file, fasta_file, max_epochs, ppl)``.
        For the configuration entries the last matching line in the file wins;
        an entry is ``None`` when no line matched. ``ppl`` is ``np.exp`` of the
        smallest ``best <float>`` value found, or ``None`` if there was none.

    Side effects:
        Prints the extracted values, including the highest ``Epoch <n>`` seen
        (which is not part of the return value).
    """
    # (name, pattern, converter) for the "key: value" configuration entries.
    config_patterns = (
        ('d_model', re.compile(r'd_model:\s*(\d+)'), int),
        ('n_layer', re.compile(r'n_layer:\s*(\d+)'), int),
        ('bed_file', re.compile(r'bed_file:\s*(\S+)'), str),
        ('fasta_file', re.compile(r'fasta_file:\s*(\S+)'), str),
        ('max_epochs', re.compile(r'max_epochs:\s*(\d+)'), int),
    )
    best_pattern = re.compile(r'(?<=best )\d+\.\d+')
    epoch_pattern = re.compile(r'Epoch\s+(\d+)')

    config = {name: None for name, _, _ in config_patterns}
    best_val_loss = None
    current_epoch = None

    with open(file, 'r') as f:
        # Stream the file instead of materialising every line up front.
        for line in f:
            # Independent checks (not an elif chain): one line may hold
            # several keys and none of them should hide the others.
            for name, pattern, convert in config_patterns:
                match = pattern.search(line)
                if match:
                    config[name] = convert(match.group(1))

            match = best_pattern.search(line)
            if match:
                curr_best = float(match.group(0))
                # Keep the lowest value reported anywhere in the log.
                if best_val_loss is None or curr_best < best_val_loss:
                    best_val_loss = curr_best

            match = epoch_pattern.search(line)
            if match:
                epoch = int(match.group(1))
                # Keep the highest epoch number reported anywhere in the log.
                if current_epoch is None or epoch > current_epoch:
                    current_epoch = epoch

    d_model = config['d_model']
    n_layer = config['n_layer']
    bed_file = config['bed_file']
    fasta_file = config['fasta_file']
    max_epochs = config['max_epochs']

    print(f'Width: {d_model}')
    print(f'Depth: {n_layer}')
    print(f'Bed file: {bed_file}')
    print(f'Fasta file: {fasta_file}')
    print(f'Max epochs: {max_epochs}')
    print(f'Epoch {current_epoch}')

    if best_val_loss is not None:
        # Perplexity is the exponential of the best loss value.
        ppl = np.exp(best_val_loss)
        print(f'PPL: {ppl}')
    else:
        ppl = None

    return d_model, n_layer, bed_file, fasta_file, max_epochs, ppl

# Folder on the cluster that holds the pretraining slurm logs.
directory = '/scratch/gpfs/sg0666/hyena-dna/slurm_pretrain'
# Every file in that folder whose name ends in .out.
files = glob(directory + '/*.out')

# Print the summary of each log, with a blank line between logs.
# The returned tuple is not needed here; extract_metrics prints what it finds.
for file in files:
    extract_metrics(file)
    print()

In [None]:
import csv
import matplotlib.pyplot as plt
import re

In [None]:
# reads all of the .out files previously made and writes one summary row per file

# results of the experiments
results = []

directory = '/Users/Supraj1/qcb455/loss/slurm_files/slurm_pretrain'

# all files named something.out
# glob returns paths in arbitrary order, so sort them to keep the CSV rows reproducible
files = sorted(glob(f'{directory}/*.out'))

if not files:
    # The path above is machine specific; make an empty result obvious.
    print(f"No .out files found in {directory}")

for file in files:
    metrics = extract_metrics(file)
    print(f"File: {file}")
    print(f"Metrics: {metrics}")
    results.append(metrics)


# save this information to a CSV file
# (column order matches the tuple returned by extract_metrics)
output_csv = 'metrics_summary_PT.csv'
header = ['d_model', 'n_layer', 'bed_file', 'fasta_file', 'max_epochs', 'PPL']

# newline='' lets the csv module control line endings itself
with open(output_csv, 'w', newline='') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(header)

    writer.writerows(results)

In [None]:
# regex (regular expression) algorithm to get the epoch losses
def extractLosses(file_path):
    """Collect one train loss and one validation loss per epoch from a log.

    The log is read line by line. A line matching ``Epoch <n>:`` marks the
    epoch the following values belong to. Within an epoch the most recent
    ``loss=<number>`` and ``val/loss=<number>`` values are held; when the
    epoch number changes (and once more at the end of the file) the held
    pair is appended to the result lists.

    Args:
        file_path: Path of the text log to read.

    Returns:
        ``(trainLosses, valLosses)``: two lists of equal length with one
        entry per epoch in order of appearance. An entry is ``None`` when no
        matching value was seen for that epoch. Both lists are empty when the
        log contains no ``Epoch <n>:`` line.
    """
    # Plain or scientific-notation float, e.g. 1, 1.25, .5, 1.2e-05.
    number = r"([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)"
    epochPattern = re.compile(r"Epoch (\d+):")
    # The lookbehinds keep "val/loss=" and "test/loss=" from being read as
    # the training loss; any other "...loss=" still counts.
    trainPattern = re.compile(r"(?<!val/)(?<!test/)loss=" + number)
    valPattern = re.compile(r"val/loss=" + number)

    trainLosses = []
    valLosses = []

    currEpoch = None
    lastTrainLoss = None
    lastValLoss = None

    # read the .out file
    with open(file_path, 'r') as f:
        for line in f:
            epochMatch = epochPattern.search(line)
            if epochMatch:
                epoch = int(epochMatch.group(1))

                # Only an actual change of epoch flushes and resets the held
                # values; repeated lines of the same epoch must not wipe a
                # loss that an earlier line of that epoch reported.
                if epoch != currEpoch:
                    if currEpoch is not None:
                        trainLosses.append(lastTrainLoss)
                        valLosses.append(lastValLoss)
                    currEpoch = epoch
                    lastTrainLoss = None
                    lastValLoss = None

            # hold training loss if found
            trainLossMatch = trainPattern.search(line)
            if trainLossMatch:
                lastTrainLoss = float(trainLossMatch.group(1))

            # hold validation loss if found
            valLossMatch = valPattern.search(line)
            if valLossMatch:
                lastValLoss = float(valLossMatch.group(1))

    # for final epoch
    if currEpoch is not None:
        trainLosses.append(lastTrainLoss)
        valLosses.append(lastValLoss)

    return trainLosses, valLosses

In [None]:
# NOTE(review): `dir` shadows the Python builtin of the same name; the name is
# kept unchanged here in case other cells read it.
dir = '/Users/Supraj1/qcb455/loss/slurm_files/slurm_pretrain'

# to get all files with .out suffix
# glob returns paths in arbitrary order; sorting keeps "File N" tied to the same log on every run
filePaths = sorted(glob(f'{dir}/*.out'))
# start at index 1 instead of 0
fileNames = [f"File {i + 1}" for i in range(len(filePaths))]

if not filePaths:
    # plt.subplots cannot build a grid with zero rows, so stop early.
    print(f"No .out files found in {dir}. Nothing to plot.")
else:
    # plotting
    # this is number of rows (two plots per row, rounded up)
    rows = (len(filePaths) + 1) // 2

    # squeeze=False always returns a 2-D array of axes, whatever the grid size
    fig, axes = plt.subplots(rows, 2, figsize=(10, rows * 5), squeeze=False)
    axes = axes.flatten()

    usedAxes = set()

    for i in range(len(filePaths)):
        filePath = filePaths[i]
        fileName = fileNames[i]
        trainLosses, valLosses = extractLosses(filePath)

        # check if extracted properly (an epoch with no value is stored as None)
        if not any(loss is not None for loss in trainLosses):
            print(f"No epoch-level train losses found in {filePath}. Skipping.")
            continue

        # x-axis is the 1-based position of each epoch in the extracted lists
        epochs = list(range(1, len(trainLosses) + 1))

        # missing values become NaN so they show up as gaps in the line
        trainSeries = [float("nan") if loss is None else loss for loss in trainLosses]
        valSeries = [float("nan") if loss is None else loss for loss in valLosses]

        # plot
        ax = axes[i]
        usedAxes.add(i)

        ax.plot(epochs, trainSeries, label="Train Loss", color="blue")
        # only draw (and list in the legend) a validation curve that has data
        if any(loss is not None for loss in valLosses):
            ax.plot(epochs, valSeries, label="Val Loss", color="orange", linestyle="--")

        ax.set_title(fileName)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.legend()
        ax.grid(True)

    # hide grid cells that received no plot (odd file count or skipped logs)
    for i, ax in enumerate(axes):
        if i not in usedAxes:
            ax.set_visible(False)

    plt.tight_layout()
    plt.show()