In [1]:
import matplotlib.pyplot as plt

from models import TransformerModelLooped, TransformerModelLoopedLastNTokens
from curriculum import CurriculumSimple
from train import train_without_config, validate_model

# Input dimensionality: passed to the models/curricula below and used to
# scale the validation loss returned by calculate_by_n_points.
n_dims = 10

In [2]:
# b=5 looped model; every model in this notebook uses this same architecture.
model_loop_b5 = TransformerModelLoopedLastNTokens(
    n_dims=n_dims, n_positions=101, n_embd=256, n_layer=1, n_head=4,
    pred_type="regression", n=None,
).cuda()

cirriculum_b5 = CurriculumSimple(
    n_dims, 31, 5,
    [5000, n_dims, 0],
    [5000, 31, 0],
    [5000, 5, 0],
)
## Fixed seed
# NOTE(review): no seed is set in this notebook itself; presumably
# train_without_config fixes it internally. Confirm before relying on it.
metrics_l1_b5 = train_without_config(
    model_loop_b5,
    cirriculum_b5,
    model_n_dims=n_dims,
    log_every_steps=10,
    train_steps=15000,
    family="gpt2_loop",
    do_wandb_log=False,
)

In [3]:
# b=10 looped model; same architecture as model_loop_b5.
model_loop_b10 = TransformerModelLoopedLastNTokens(
    n_dims=n_dims, n_positions=101, n_embd=256, n_layer=1, n_head=4,
    pred_type="regression", n=None,
).cuda()

cirriculum_b10 = CurriculumSimple(
    n_dims, 31, 10,
    [5000, n_dims, 0],
    [5000, 31, 0],
    [5000, 10, 0],
)

## Fixed seed
# NOTE(review): no seed is set in this notebook itself; presumably
# train_without_config fixes it internally. Confirm before relying on it.
metrics_l1_b10 = train_without_config(
    model_loop_b10,
    cirriculum_b10,
    model_n_dims=n_dims,
    log_every_steps=10,
    train_steps=15000,
    family="gpt2_loop",
    do_wandb_log=False,
)

In [5]:
def calculate_by_n_points(model, max_n_points, n_loops=5):
    """Validate `model` once for every value of `model.n` in 1..max_n_points (inclusive).

    Args:
        model: model exposing a writable `n` attribute and `eval()`
            (a TransformerModelLoopedLastNTokens in this notebook).
        max_n_points: largest value of `model.n` to evaluate; also passed to
            `validate_model` as `n_points`.
        n_loops: number of loop iterations passed to `validate_model`.

    Returns:
        (n_values, vals): `n_values[k]` is the `model.n` used for run k and
        `vals[k]` is that run's validation loss divided by `n_dims`.

    The model is put in eval mode. Its original `n` is restored on exit, even if
    validation raises, so later cells do not silently inherit the last value set here.
    """
    vals = []
    n_values = []
    model.eval()
    original_n = model.n
    try:
        for i in range(1, max_n_points + 1):
            model.n = i
            val_loss = validate_model(model, n_dims_truncated=n_dims, n_loops=n_loops,
                                      model_n_dims=n_dims, family="gpt2_loop",
                                      n_points=max_n_points)
            # Loss is reported per input dimension.
            vals.append(val_loss / n_dims)
            n_values.append(i)
    finally:
        model.n = original_n
    return n_values, vals
# Each locally trained model is evaluated with n_loops equal to its b value.
steps_points_b5, values_points_b5 = calculate_by_n_points(model_loop_b5, max_n_points=31, n_loops=5)
steps_points_b10, values_points_b10 = calculate_by_n_points(model_loop_b10, max_n_points=31, n_loops=10)

In [10]:
# Per-n validation loss for the two locally trained models (b=5, b=10).
fig_local, ax_local = plt.subplots()
ax_local.plot(steps_points_b5, values_points_b5, label="b=5")
ax_local.plot(steps_points_b10, values_points_b10, label="b=10")
ax_local.set_xlabel("n (last tokens removed)")
ax_local.set_ylabel("validation loss / n_dims")
ax_local.grid()
ax_local.legend()
# Distinct filename: the final figure of this notebook is written to
# '../images/check_last_n_tokens_quality.png'. Sharing that path meant
# whichever cell ran last silently overwrote the other figure.
fig_local.savefig('../images/check_last_n_tokens_quality_local.png')

In [2]:
import torch
## Random seed, trained not locally
# Rebuild each model with the shared architecture and restore its weights from
# ../scripts/scripts/models/noisy_linear_regression/model_b{b}.pt (a checkpoint
# dict holding a "model_state_dict" entry), for b in 5, 10, 20.
#
# map_location="cpu": these checkpoints were trained on another machine, so
# their tensors may carry a CUDA device index that does not exist locally.
# Deserialising to CPU lets load_state_dict copy the values into the
# already-.cuda() parameters, independent of the saving machine's GPU layout.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
_models_by_b = {}
for _b in (5, 10, 20):
    _model = TransformerModelLoopedLastNTokens(n_dims=10,
                                               n_positions=101,
                                               n_embd=256,
                                               n_layer=1,
                                               n_head=4,
                                               pred_type="regression",
                                               n=None).cuda()
    _checkpoint = torch.load(
        f"../scripts/scripts/models/noisy_linear_regression/model_b{_b}.pt",
        map_location="cpu")
    _model.load_state_dict(_checkpoint["model_state_dict"])
    _models_by_b[_b] = _model
# Free the last deserialised checkpoint; the weights now live in the models.
del _model, _checkpoint

model_b5, model_b10, model_b20 = _models_by_b[5], _models_by_b[10], _models_by_b[20]

In [3]:
## Random seed, trained not locally
# Same procedure for the "_t10" checkpoints:
# ../scripts/scripts/models/noisy_linear_regression/model_b{b}_t10.pt for b in 5, 10, 20.
#
# map_location="cpu": these checkpoints were trained on another machine, so
# their tensors may carry a CUDA device index that does not exist locally.
# Deserialising to CPU lets load_state_dict copy the values into the
# already-.cuda() parameters, independent of the saving machine's GPU layout.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
_models_by_b_t10 = {}
for _b in (5, 10, 20):
    _model = TransformerModelLoopedLastNTokens(n_dims=10,
                                               n_positions=101,
                                               n_embd=256,
                                               n_layer=1,
                                               n_head=4,
                                               pred_type="regression",
                                               n=None).cuda()
    _checkpoint = torch.load(
        f"../scripts/scripts/models/noisy_linear_regression/model_b{_b}_t10.pt",
        map_location="cpu")
    _model.load_state_dict(_checkpoint["model_state_dict"])
    _models_by_b_t10[_b] = _model
# Free the last deserialised checkpoint; the weights now live in the models.
del _model, _checkpoint

model_b5_t10, model_b10_t10, model_b20_t10 = (
    _models_by_b_t10[5], _models_by_b_t10[10], _models_by_b_t10[20])

In [4]:
def calculate_by_n_points(model, max_n_points, n_loops=5):
    """Validate `model` once for every value of `model.n` in 1..max_n_points (inclusive).

    Args:
        model: model exposing a writable `n` attribute and `eval()`
            (a TransformerModelLoopedLastNTokens in this notebook).
        max_n_points: largest value of `model.n` to evaluate; also passed to
            `validate_model` as `n_points`.
        n_loops: number of loop iterations passed to `validate_model`.

    Returns:
        (n_values, vals): `n_values[k]` is the `model.n` used for run k and
        `vals[k]` is that run's validation loss divided by `n_dims`.

    The model is put in eval mode. Its original `n` is restored on exit, even if
    validation raises, so later cells do not silently inherit the last value set here.
    """
    vals = []
    n_values = []
    model.eval()
    original_n = model.n
    try:
        for i in range(1, max_n_points + 1):
            model.n = i
            val_loss = validate_model(model, n_dims_truncated=n_dims, n_loops=n_loops,
                                      model_n_dims=n_dims, family="gpt2_loop",
                                      n_points=max_n_points)
            # Loss is reported per input dimension.
            vals.append(val_loss / n_dims)
            n_values.append(i)
    finally:
        model.n = original_n
    return n_values, vals

In [5]:
# Standard loop budget: each checkpoint is run with n_loops equal to its b value.
[
    (steps_points_b5_, values_points_b5_),
    (steps_points_b10_, values_points_b10_),
    (steps_points_b20_, values_points_b20_),
] = [
    calculate_by_n_points(model, max_n_points=31, n_loops=b)
    for model, b in ((model_b5, 5), (model_b10, 10), (model_b20, 20))
]

In [6]:
# Doubled loop budget: n_loops = 2 * b for each checkpoint.
[
    (steps_points_b5_x2, values_points_b5_x2),
    (steps_points_b10_x2, values_points_b10_x2),
    (steps_points_b20_x2, values_points_b20_x2),
] = [
    calculate_by_n_points(model, max_n_points=31, n_loops=b * 2)
    for model, b in ((model_b5, 5), (model_b10, 10), (model_b20, 20))
]

In [7]:
# "_t10" checkpoints at the standard (x1) and doubled (x2) loop budgets,
# evaluated in the same order as before: all x1 runs, then all x2 runs.
[
    (steps_points_b5_t10_, values_points_b5_t10_),
    (steps_points_b10_t10_, values_points_b10_t10_),
    (steps_points_b20_t10_, values_points_b20_t10_),
    (steps_points_b5_t10_x2, values_points_b5_t10_x2),
    (steps_points_b10_t10_x2, values_points_b10_t10_x2),
    (steps_points_b20_t10_x2, values_points_b20_t10_x2),
] = [
    calculate_by_n_points(model, max_n_points=31, n_loops=b * factor)
    for factor in (1, 2)
    for model, b in ((model_b5_t10, 5), (model_b10_t10, 10), (model_b20_t10, 20))
]

In [14]:
# Tripled loop budget (n_loops = 3 * b) for both checkpoint families.
[
    (steps_points_b5_x3, values_points_b5_x3),
    (steps_points_b10_x3, values_points_b10_x3),
    (steps_points_b20_x3, values_points_b20_x3),
    (steps_points_b5_t10_x3, values_points_b5_t10_x3),
    (steps_points_b10_t10_x3, values_points_b10_t10_x3),
    (steps_points_b20_t10_x3, values_points_b20_t10_x3),
] = [
    calculate_by_n_points(model, max_n_points=31, n_loops=b * 3)
    for model, b in (
        (model_b5, 5), (model_b10, 10), (model_b20, 20),
        (model_b5_t10, 5), (model_b10_t10, 10), (model_b20_t10, 20),
    )
]

In [16]:
fig, axs = plt.subplots(2, 3, figsize=(14, 6))

fig.suptitle('Remove $ n $ last tokens', fontsize=16)

# Rows: checkpoint family. Columns: loop budget at evaluation time (x1, x2, x3 of b).
row_labels = ["T = 20", "T = 10"]
col_titles = ["Standard loop", "Doubled loop", "Tripled loop"]
col_linestyles = ["solid", "dashed", "dashdot"]
curve_labels = ["b=5", "b=10", "b=20"]

# panel_curves[row][col] -> one (steps, values) pair per entry of curve_labels.
panel_curves = [
    [
        [(steps_points_b5_, values_points_b5_),
         (steps_points_b10_, values_points_b10_),
         (steps_points_b20_, values_points_b20_)],
        [(steps_points_b5_x2, values_points_b5_x2),
         (steps_points_b10_x2, values_points_b10_x2),
         (steps_points_b20_x2, values_points_b20_x2)],
        [(steps_points_b5_x3, values_points_b5_x3),
         (steps_points_b10_x3, values_points_b10_x3),
         (steps_points_b20_x3, values_points_b20_x3)],
    ],
    [
        [(steps_points_b5_t10_, values_points_b5_t10_),
         (steps_points_b10_t10_, values_points_b10_t10_),
         (steps_points_b20_t10_, values_points_b20_t10_)],
        [(steps_points_b5_t10_x2, values_points_b5_t10_x2),
         (steps_points_b10_t10_x2, values_points_b10_t10_x2),
         (steps_points_b20_t10_x2, values_points_b20_t10_x2)],
        [(steps_points_b5_t10_x3, values_points_b5_t10_x3),
         (steps_points_b10_t10_x3, values_points_b10_t10_x3),
         (steps_points_b20_t10_x3, values_points_b20_t10_x3)],
    ],
]

for row, row_panels in enumerate(panel_curves):
    for col, curves in enumerate(row_panels):
        ax = axs[row, col]
        for (steps, values), label in zip(curves, curve_labels):
            ax.plot(steps, values, linestyle=col_linestyles[col], label=label)
        ax.set_ylim([0, 0.4])
        ax.grid()
        # Neutral colour for the n = 31 marker: matplotlib's default colour cycle
        # draws the second curve (b=10) in orange, so an orange marker was
        # indistinguishable from it. No label, so it stays out of the legend.
        ax.axvline(x=31, color='gray', linestyle='dashed')
        ax.legend()
    axs[row, 0].set_ylabel(row_labels[row])

for col, title in enumerate(col_titles):
    axs[0, col].set_title(title)
    axs[-1, col].set_xlabel("n (last tokens removed)")

fig.savefig('../images/check_last_n_tokens_quality.png')