In [1]:
#@title Connect to Kaggle
import kagglehub

# Fetch the species-name dataset through kagglehub. The returned value is used
# below as a directory prefix for "/species_names.txt" and "/splits2.dill".
utsavlal_species_names_path = kagglehub.dataset_download('utsavlal/species-names')
# Fetch the pretrained VAE weights; the returned value is used below as the
# directory prefix for "/VAE_names_nohiddenmorelayers.pth".
utsavlal_species_names_text_vae_pytorch_default_1_path = kagglehub.model_download('utsavlal/species-names-text-vae/PyTorch/default/1')

print('Data source import complete.')


Using Colab cache for faster access to the 'species-names' dataset.
Data source import complete.


In [2]:
#@title Imports, Data Preparation, and Functions

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import csv
from tqdm import tqdm
import sys
import random
import ipywidgets as widgets
from IPython.display import display
from IPython.display import clear_output

# Lift the csv field-size cap and pin the stdlib RNG so random picks repeat.
csv.field_size_limit(sys.maxsize)
random.seed(234)


# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Retrieve Text Data: one species name per line, surrounding whitespace removed.
with open(utsavlal_species_names_path+"/species_names.txt", "r") as f:
    names = list(map(str.strip, f))

# Character vocabulary: '$' (the start / end marker) takes index 0, followed by
# every distinct character found in the names, in sorted order.
vocab_names = ['$', *sorted(set(''.join(names)))]

# Lookup tables in both directions (character -> index and index -> character).
stoi_names = dict(zip(vocab_names, range(len(vocab_names))))
itos_names = {index: char for char, index in stoi_names.items()}

print(f"names uses {len(vocab_names)} characters: {''.join(vocab_names)}")

# Retrieve Datasets
!pip install dill
import dill
save_path = utsavlal_species_names_path+"/splits2.dill"
# SECURITY NOTE: dill (like pickle) can execute arbitrary code while loading.
# Only load split files that come from a source you trust.
# The context manager closes the file handle once loading is done.
with open(save_path, "rb") as split_file:
    things = dill.load(split_file)
# Every split tensor is used as integer indices / lengths, so cast to long and
# place it on the working device in one pass.
things = [split.to(torch.long).to(device) for split in things]
x_train_names, x_train_names_lens, y_train_names, y_train_names_lens, x_val_names, x_val_names_lens, y_val_names, y_val_names_lens = things

# Autoencoder Class
class VAEAutoencoder(nn.Module):
  """Character-level sequence VAE with an LSTM encoder and an LSTM decoder.

  The encoder packs the embedded input, runs an LSTM and maps the concatenated
  final cell and hidden states to two vectors of size ``latent_dim``. The
  decoder turns a latent vector into the initial LSTM state and additionally
  concatenates a (possibly transformed) copy of the latent to the token
  embedding at every time step.

  Note on naming: the second encoder head is called ``logvar`` throughout, but
  it is used as a log *standard deviation*: ``reparametrize`` scales the noise
  by ``exp(logvar)`` and ``vae_loss`` uses ``exp(2*logvar)`` as the variance.
  The two are consistent with each other, so the name is kept for
  compatibility with saved checkpoints.

  Args:
    vocab_size: number of distinct tokens (size of the embedding / output).
    embed_dim: token embedding width.
    hidden_size: LSTM hidden width (encoder and decoder).
    latent_dim: size of the latent vector.
    num_layers_encode: number of stacked encoder LSTM layers.
    num_layers_decode: number of stacked decoder LSTM layers.
    decode_sequence_layers: depth of the optional MLPs applied to the latent
      before it is used (0 means identity).
    decode_sequence_width: hidden width of those MLPs.
    loss_weight: multiplier on the KL term in ``vae_loss``.
  """

  def __init__(self, vocab_size, embed_dim=16, hidden_size=256, latent_dim=30, num_layers_encode=1, num_layers_decode=1, decode_sequence_layers=0, decode_sequence_width=30, loss_weight=1):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, embed_dim)
    self.num_layers_encode = num_layers_encode
    self.num_layers_decode = num_layers_decode
    self.vocab_size = vocab_size
    self.embed_dim = embed_dim
    self.hidden_size = hidden_size
    self.latent_dim = latent_dim
    self.loss_weight = loss_weight

    # Encoder: the heads read the final (cell, hidden) state of every layer,
    # hence the 2 * hidden_size * num_layers input width.
    self.encoder_lstm = nn.LSTM(embed_dim, hidden_size, num_layers=num_layers_encode, batch_first=True)
    self.encode_sequence_mean = nn.Sequential(
        nn.Linear(2*hidden_size*self.num_layers_encode, latent_dim)
    )
    self.encode_sequence_logvar = nn.Sequential(
        nn.Linear(2*hidden_size*self.num_layers_encode, latent_dim)
    )

    # Decoder: three structurally identical MLPs (per-step conditioning, cell
    # state, hidden state) mapping latent_dim -> latent_dim.
    modules = []
    modulec = []
    moduleh = []
    last_layer = decode_sequence_layers - 1
    for j in range(decode_sequence_layers):
      if decode_sequence_layers == 1:
        in_features, out_features = latent_dim, latent_dim
      elif j == 0:
        in_features, out_features = latent_dim, decode_sequence_width
      elif j == last_layer:
        in_features, out_features = decode_sequence_width, latent_dim
      else:
        in_features, out_features = decode_sequence_width, decode_sequence_width
      modules.append(nn.Linear(in_features, out_features))
      modulec.append(nn.Linear(in_features, out_features))
      moduleh.append(nn.Linear(in_features, out_features))
      # Dropout + ReLU go between layers only. The original compared the layer's
      # input width (not its index) against the last index, which placed an
      # activation after the final projection as well.
      if j != last_layer:
        modules.append(nn.Dropout())
        modulec.append(nn.Dropout())
        moduleh.append(nn.Dropout())
        modules.append(nn.ReLU())
        modulec.append(nn.ReLU())
        moduleh.append(nn.ReLU())
    self.decode_sequence = nn.Sequential(*modules)
    self.c_sequence = nn.Sequential(*modulec)
    self.h_sequence = nn.Sequential(*moduleh)
    self.decode_cell = nn.Linear(latent_dim, hidden_size*num_layers_decode)
    self.decode_hidden = nn.Linear(latent_dim, hidden_size*num_layers_decode)
    self.decoder_lstm = nn.LSTM(embed_dim+latent_dim, hidden_size, num_layers=num_layers_decode, batch_first=True)
    self.output_layer = nn.Linear(hidden_size, vocab_size)

  def reparametrize(self, mean, logvar):
    """Sample ``mean + exp(logvar) * eps`` with ``eps ~ N(0, I)``.

    The noise is created directly with the shape, dtype and device of ``mean``
    so the method does not depend on a module-level ``device``.
    """
    return torch.exp(logvar) * torch.randn_like(mean) + mean

  def encode(self, x):
    """Encode a ``(tokens, lengths)`` pair.

    Returns:
      ``(latent, mean, logvar)`` where ``latent`` is a reparametrized sample.
    """
    x, lengths = x
    x_embed = self.embedding(x)
    # pack_padded_sequence requires the lengths on the CPU.
    packed = torch.nn.utils.rnn.pack_padded_sequence(x_embed, lengths.cpu(), batch_first=True, enforce_sorted=False)
    output, hidden = self.encoder_lstm(packed)
    # hidden = (h_n, c_n), each (layers, batch, hidden). Put batch first, join
    # cell and hidden along the feature axis, then flatten layers and features.
    mem = torch.flatten(torch.cat((hidden[1].transpose(0, 1), hidden[0].transpose(0, 1)), dim=2), start_dim=1)
    mean = self.encode_sequence_mean(mem)
    logvar = self.encode_sequence_logvar(mem)
    latent = self.reparametrize(mean, logvar)
    return latent, mean, logvar

  def decode(self, latent):
    """Map a latent batch to ``(l, c_0, h_0)``.

    ``l`` is the per-step conditioning vector; ``c_0`` / ``h_0`` are the initial
    decoder LSTM states shaped (num_layers_decode, batch, hidden_size).
    """
    l = self.decode_sequence(latent)
    c_0 = self.decode_cell(self.c_sequence(latent))
    h_0 = self.decode_hidden(self.h_sequence(latent))
    c_0 = torch.reshape(c_0, (c_0.shape[0], self.num_layers_decode, self.hidden_size))
    h_0 = torch.reshape(h_0, (h_0.shape[0], self.num_layers_decode, self.hidden_size))
    # nn.LSTM expects (layers, batch, hidden) contiguous state tensors.
    c_0 = c_0.transpose(0, 1).contiguous()
    h_0 = h_0.transpose(0, 1).contiguous()
    return l, c_0, h_0

  def decode_input(self, x, l):
    """Build decoder inputs: ``l`` repeated at every step, then the embedding."""
    x_embed = self.embedding(x)
    # Broadcast the (batch, latent) conditioning vector across the time axis.
    c = l.unsqueeze(1).expand(-1, x_embed.shape[1], -1)
    inp = torch.cat((c, x_embed), dim=2)
    return inp

  def forward(self, x, teacher_forcing=False, noteachforcenum=100000000):
    """Reconstruct ``x`` and return ``(logits, mean, logvar)``.

    Args:
      x: ``(tokens, lengths)`` pair.
      teacher_forcing: when True the decoder is fed the ground-truth tokens for
        the whole sequence in one LSTM call.
      noteachforcenum: only used without teacher forcing. A counter is
        incremented on every step; once it reaches this value the next decoder
        input is taken from the ground truth and the counter is reduced by it.

    Returns:
      ``logits`` shaped (batch, vocab_size, seq_len) (the layout
      ``nn.CrossEntropyLoss`` expects), plus the encoder ``mean`` and ``logvar``.
    """
    latent, mean, logvar = self.encode(x)
    l, c_0, h_0 = self.decode(latent)
    x, lengths = x

    if teacher_forcing:
      inp = self.decode_input(x, l)
      out, _ = self.decoder_lstm(inp, (h_0, c_0))
      logits = self.output_layer(out).transpose(1, 2)
      return logits, mean, logvar

    t = 0
    # Decoding starts from token index 0 for every sequence in the batch.
    perchar = torch.zeros(x.shape[0], 1, dtype=torch.long, device=x.device)
    inp = self.decode_input(perchar, l)
    maininp = self.decode_input(x, l)
    logits = torch.zeros(x.shape[0], x.shape[1], self.vocab_size, device=x.device)
    for i in range(x.shape[1]):
      out, (h_0, c_0) = self.decoder_lstm(inp, (h_0, c_0))
      probs = self.output_layer(out)
      logits[:, i:i+1, :] = probs
      if t >= noteachforcenum:
        t -= noteachforcenum
        # NOTE(review): this feeds x[:, i] as the input of step i+1, whereas the
        # teacher-forced path above feeds x[:, i+1] there. It looks like an
        # off-by-one; confirm against the x / y layout before changing it.
        inp = maininp[:, i:i+1, :]
      else:
        # Feed the model's own greedy prediction back in.
        perchar = torch.argmax(probs, dim=2)
        inp = self.decode_input(perchar, l)
      t += 1
    return logits.transpose(1, 2), mean, logvar

  def vae_loss(self, x, y, teacher_forcing=False, noteachforcenum=100000000):
    """Reconstruction cross-entropy plus ``loss_weight`` times the KL term.

    Targets equal to -1 are ignored by the cross-entropy. The KL term is the
    mean over all elements of ``(mean^2 + sigma^2 - 1)/2 - log(sigma)`` with
    ``sigma = exp(logvar)``.
    """
    logits, mean, logvar = self.forward(x, teacher_forcing=teacher_forcing, noteachforcenum=noteachforcenum)
    loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
    recon_loss = loss_fn(logits, y)
    kl_loss = torch.mean((mean*mean+torch.exp(2*logvar)-1)/2-logvar)
    return recon_loss+self.loss_weight*kl_loss



# Train Function
import time
from torch.utils.data import Dataset

class SequenceDataset(Dataset):
    """Dataset yielding ``(x[idx], lengths[idx], y[idx])`` triples."""

    def __init__(self, x, lengths, y):
        # Keep references to the three parallel containers; nothing is copied.
        self.x, self.lengths, self.y = x, lengths, y

    def __len__(self):
        """Number of examples, taken from ``x``."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the input, its length and its target at position ``idx``."""
        return tuple(field[idx] for field in (self.x, self.lengths, self.y))

def train(model, x_train, x_train_lens, y_train, x_val, x_val_lens, y_val, stoi, gpu=True, start="",
          num_epochs=5, batch_size=500, lr=0.001, report=1000, do_teach=True, start_teach=0, end_teach=0, save_path = ""):
    """Train ``model`` with Adam on ``model.vae_loss`` and plot the loss curves.

    Args:
        model: object exposing ``vae_loss``, ``parameters``, ``train``/``eval``.
        x_train, x_train_lens, y_train: training inputs, lengths and targets.
        x_val, x_val_lens, y_val: validation inputs, lengths and targets.
        stoi: character -> index mapping, forwarded to the sampling snapshot.
        gpu: when True, move the data tensors and the model to ``device``.
        start: unused; kept for backward compatibility.
        num_epochs: number of passes over the training set.
        batch_size: mini-batch size.
        lr: Adam learning rate.
        report: print a sample reconstruction every ``report`` optimizer steps.
        do_teach: passed to ``vae_loss`` as ``teacher_forcing``.
        start_teach, end_teach: ``noteachforcenum`` is moved linearly from
            ``start_teach`` (first epoch) to ``end_teach`` (last epoch).
        save_path: where the best-validation ``state_dict`` is written; when
            empty, no checkpoint is saved.
    """
    t1 = time.time()
    if gpu:
        x_train = x_train.to(device)
        y_train = y_train.to(device)
        x_val = x_val.to(device)
        y_val = y_val.to(device)
        model.to(device)

    # Set up data loader for batched training
    train_dataset = SequenceDataset(x_train, x_train_lens, y_train)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Set up optimizer (the loss itself lives on the model)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    t = 0  # global optimizer-step counter across epochs

    def snapshot(loss):
        """Print the current loss and one sampled reconstruction."""
        print("Train Loss: "+str(loss))
        model.eval()
        # NOTE(review): sample_rnn is not defined in the visible part of this
        # notebook; this call raises NameError unless it is defined elsewhere.
        original, result = sample_rnn(model, stoi, x_val, x_val_lens, T=0.01)
        print("Original:" + original)
        print("Result:" + result)
        print()

    losslist = []
    vallosslist = []

    bestval = None

    tf = do_teach
    teach_val = start_teach
    # Per-epoch step so teach_val reaches end_teach on the last epoch. The
    # original lacked the parentheses (dividing only start_teach) and divided by
    # zero for a single epoch.
    inc = (end_teach - start_teach) / (num_epochs - 1) if num_epochs > 1 else 0

    for epoch in range(num_epochs):
        print("Epoch "+str(epoch))
        print()
        t2 = 0
        totalloss = 0
        for batch_x, lens, batch_y in train_loader:
            # snapshot() leaves the model in eval mode, so re-enter train mode
            # at the start of every step.
            model.train()
            t2 += 1
            optimizer.zero_grad()
            loss = model.vae_loss((batch_x, lens), batch_y, teacher_forcing=tf, noteachforcenum=teach_val)
            totalloss += loss.item()
            loss.backward()
            optimizer.step()

            t += 1
            if t % report == 0:
                snapshot(loss.item())

        # Guard against an empty loader instead of dividing by zero.
        totalloss /= max(t2, 1)
        losslist.append(totalloss)

        # Get the entire validation loss at the end of each epoch
        model.eval()
        with torch.no_grad():
            val_loss = model.vae_loss((x_val, x_val_lens), y_val, teacher_forcing=tf, noteachforcenum=teach_val).item()
        model.train()

        vallosslist.append(val_loss)

        if bestval is None or bestval > val_loss:
            # Only write a checkpoint when a destination was given; saving to
            # the default "" would abort training with a file error.
            if save_path:
                torch.save(model.state_dict(), save_path)
            bestval = val_loss

        print("Epoch Train Loss: "+str(totalloss)+", Val Loss: "+str(val_loss))
        print()

        teach_val += inc

    t2 = time.time()

    print("Training Time: " + str(t2-t1))

    plt.plot(losslist, label="train")
    plt.plot(vallosslist, label="validation")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.legend()
    plt.show()


# Evaluation Functions

import random
def try_rnn(model, stoi, latent, max_length=100, temperature=0.01):
    """Decode a single latent vector into a string by sampling token by token.

    Decoding starts from token index 0 and stops when index 0 is sampled again
    or when ``max_length`` characters have been produced.

    Args:
        model: object exposing ``latent_dim``, ``decode``, ``decode_input``,
            ``decoder_lstm`` and ``output_layer``.
        stoi: character -> index mapping; inverted here to turn samples into text.
        latent: anything convertible to a tensor with ``model.latent_dim``
            elements (list, numpy array or tensor on any device).
        max_length: maximum number of characters to emit.
        temperature: softmax temperature; the default 0.01 makes sampling
            almost greedy.

    Returns:
        The decoded string, without the boundary token.
    """
    itos = {i:s for s,i in stoi.items()}
    model.eval()
    # Decode on whatever device the model already lives on instead of relying
    # on a module-level `device`.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')
    # torch.as_tensor accepts lists, arrays and tensors on any device, unlike
    # the legacy torch.Tensor(...) constructor used previously.
    latent = torch.as_tensor(latent, dtype=torch.float32).detach()
    latent = torch.reshape(latent, (1, model.latent_dim)).to(run_device)

    pieces = []
    with torch.no_grad():
        l, c_0, h_0 = model.decode(latent)
        hidden = (h_0, c_0)
        token = torch.zeros(1, 1, dtype=torch.long, device=run_device)
        while True:
            inp = model.decode_input(token, l)
            out, hidden = model.decoder_lstm(inp, hidden)
            scores = model.output_layer(out)[:, 0, :]
            probs = torch.softmax(scores/temperature, dim=1).squeeze(0).cpu().numpy(force=True)
            next_token = int(np.random.choice(len(probs), p=probs))
            # Index 0 is the end marker; also stop once the length cap is hit.
            if next_token == 0 or len(pieces) >= max_length:
                break
            pieces.append(itos[next_token])
            token[0, 0] = next_token

    return ''.join(pieces)

def interpolate2(model, stoi, x_val, x_val_lens, steps=10, max_length=100):
  """Print decodings along the straight line between two random validation items.

  Two rows of ``x_val`` are picked with ``random.randrange``, encoded, and the
  encoder means are linearly blended in ``steps`` equal increments. The first
  source sequence is printed before the blends and the second one after.
  """
  def pick_example():
    # One-row slices keep the batch dimension the encoder expects.
    row = random.randrange(x_val.shape[0])
    return x_val[row:row+1], x_val_lens[row:row+1]

  def as_text(tokens, lookup):
    return "".join(lookup[tok.item()] for tok in tokens[0])

  first_tokens, first_len = pick_example()
  second_tokens, second_len = pick_example()
  # encode returns (sample, mean, logvar); only the mean is interpolated.
  first_mean = model.encode((first_tokens, first_len))[1]
  second_mean = model.encode((second_tokens, second_len))[1]
  index_to_char = {index: char for char, index in stoi.items()}

  print(as_text(first_tokens, index_to_char))

  for step in range(steps+1):
    blend = ((steps-step)*first_mean+step*second_mean)/steps
    print(try_rnn(model, stoi, blend, max_length=max_length))

  print(as_text(second_tokens, index_to_char))

def inpinterpolatemean(model, stoi, start, end, steps=10, max_length=100):
    """Print decodings interpolated between the encodings of two strings.

    Each string is wrapped in the index-0 boundary character, mapped through
    ``stoi``, encoded, and the two encoder means are blended linearly in
    ``steps`` equal increments. Every character must be a key of ``stoi``.
    """
    index_to_char = {index: char for char, index in stoi.items()}
    boundary = index_to_char[0]

    def to_batch(text):
        # Token ids for boundary + text + boundary, as a batch of one.
        ids = [stoi[char] for char in boundary + text + boundary]
        tokens = torch.Tensor([ids]).to(torch.long).to(device)
        length = torch.Tensor([len(ids)]).to(torch.long).to(device)
        return tokens, length

    start_batch = to_batch(start)
    end_batch = to_batch(end)
    # encode returns (sample, mean, logvar); the mean is what gets interpolated.
    start_mean = model.encode(start_batch)[1]
    end_mean = model.encode(end_batch)[1]

    print("Start: "+start)
    for step in range(steps+1):
        blend = ((steps-step)*start_mean+step*end_mean)/steps
        print(try_rnn(model, stoi, blend, max_length=max_length))
    print("End: "+end)


def encodestring(string):
    """Encode ``string`` with the global ``model`` and return the encoder mean.

    The string is wrapped in the index-0 boundary character of ``stoi_names``
    before being mapped to token ids; every character must be in ``stoi_names``.
    """
    boundary = {index: char for char, index in stoi_names.items()}[0]
    token_ids = [stoi_names[char] for char in boundary + string + boundary]
    # Batch of one: the token ids and the matching sequence length.
    inpt = torch.Tensor([token_ids]).to(torch.long).to(device)
    leninp = torch.Tensor([len(token_ids)]).to(torch.long).to(device)
    # encode returns (sample, mean, logvar); keep the mean only.
    return model.encode((inpt, leninp))[1]

# Rebuild the architecture the checkpoint was saved from, then swap in the
# custom state-projection MLPs so the parameter names and shapes line up.
nohiddenmorelayers = VAEAutoencoder(len(vocab_names), latent_dim=200, num_layers_encode=2, num_layers_decode=2, hidden_size=500, loss_weight=1.6)
nohiddenmorelayers.c_sequence = nn.Sequential(nn.Linear(200, 1), nn.ReLU(), nn.Linear(1, 200))
nohiddenmorelayers.h_sequence = nn.Sequential(nn.Linear(200, 1), nn.ReLU(), nn.Linear(1, 200))
# NOTE(review): this registers `decoder_sequence`, while VAEAutoencoder.decode
# reads `decode_sequence` (left empty here), so these layers are loaded but
# never applied. The name is kept because load_state_dict below presumably
# needs matching keys in the checkpoint — confirm before renaming.
nohiddenmorelayers.decoder_sequence = nn.Sequential(nn.Linear(200, 400), nn.ReLU(), nn.Linear(400, 400), nn.ReLU(), nn.Linear(400, 200))
# weights_only=True restricts unpickling to tensors / plain containers, which
# is all a state_dict needs and is safer for a downloaded file.
nohiddenmorelayers.load_state_dict(torch.load(utsavlal_species_names_text_vae_pytorch_default_1_path+"/VAE_names_nohiddenmorelayers.pth", map_location=device, weights_only=True))
# The helpers below move their inputs to `device`, so the model has to live
# there too; eval mode because the model is only used for inference here.
nohiddenmorelayers.to(device)
nohiddenmorelayers.eval()

model = nohiddenmorelayers

def askinterpolate():
    """Prompt for two names and a step count, then print the interpolation.

    Both names are passed through ``fixstring`` so characters outside the
    vocabulary are dropped instead of raising ``KeyError`` during encoding.
    An invalid step count prints a message and returns without interpolating.
    """
    start = fixstring(input("What is your first animal? (abcdefghijklmnopqrstuvwxyz -' only) "))
    end = fixstring(input("What is your second animal? "))
    try:
        steps = int(input("How many steps of interpolation do you want? "))
    except ValueError:
        print("Steps must be a number.")
        return
    if steps < 1:
        # The blend divides by `steps`, so zero or negative counts are rejected.
        print("Steps must be at least 1.")
        return
    inpinterpolatemean(model, stoi_names, start, end, steps=steps)

def fixstring(string):
    """Lower-case ``string`` and drop every character not in ``stoi_names``."""
    return "".join(char for char in string.lower() if char in stoi_names)

def interpolate2d(rows=6, cols=6):
    """Show a grid of text cells that is filled by 2-D latent interpolation.

    The top-left, top-right and bottom-left cells are editable. Pressing the
    button encodes those three names and fills every other cell with the
    decoding of ``top_left + col/(cols-1) * (top_right - top_left)
    + row/(rows-1) * (bottom_left - top_left)``.

    Args:
        rows: number of grid rows (at least 2).
        cols: number of grid columns (at least 2).

    Raises:
        ValueError: if ``rows`` or ``cols`` is smaller than 2.
    """
    if rows < 2 or cols < 2:
        raise ValueError("interpolate2d needs at least 2 rows and 2 columns.")

    # Table size
    ROWS, COLS = rows, cols

    # Corner positions to allow input, as (row, column)
    editable_positions = {
        "top-left": (0, 0),
        "top-right": (0, COLS - 1),
        "bottom-left": (ROWS - 1, 0),
    }
    editable_cells = set(editable_positions.values())

    # Create 2D grid of widgets, locking all except the editable corners
    grid = []
    for i in range(ROWS):
        row = []
        for j in range(COLS):
            editable = (i, j) in editable_cells
            cell = widgets.Textarea(
                layout=widgets.Layout(
                    width="100px",
                    height="40px",
                    overflow="auto",
                    white_space="pre-wrap"  # Wraps lines
                ),
                disabled=not editable,
                style={"font_family": "monospace"}  # Optional: better for alignment
            )
            row.append(cell)
        grid.append(row)

    # Create the GridBox layout (children are listed row by row)
    table_box = widgets.GridBox(
        children=[cell for row in grid for cell in row],
        layout=widgets.Layout(
            grid_template_columns=" ".join(["100px"] * COLS),
            grid_gap="5px 5px"
        )
    )

    # Fill function — update only locked cells
    def fill_table(change=None):
        # Sanitize each corner, write the cleaned text back, and encode it.
        corner_latents = {}
        for corner, (r, c) in editable_positions.items():
            cleaned = fixstring(grid[r][c].value)
            grid[r][c].value = cleaned
            corner_latents[corner] = encodestring(cleaned)
        top_left = corner_latents["top-left"]
        top_right = corner_latents["top-right"]
        bottom_left = corner_latents["bottom-left"]

        # The column index blends towards the top-right corner and the row index
        # towards the bottom-left corner. top_left is counted once by each blend,
        # so it is subtracted once. (Previously the cell was addressed as
        # grid[j][i] with swapped ranges, which only worked for square grids.)
        for col in range(COLS):
            for row in range(ROWS):
                if (row, col) in editable_cells:
                    continue
                across = ((COLS-1-col)*top_left+col*top_right)/(COLS-1)
                down = ((ROWS-1-row)*top_left+row*bottom_left)/(ROWS-1)
                grid[row][col].value = try_rnn(model, stoi_names, across+down-top_left)

    # Generate button
    generate_btn = widgets.Button(description="Generate Table")
    generate_btn.on_click(fill_table)

    # Display everything
    display(widgets.HTML("<b>Type only in the 3 corners, then click Generate:</b>"))
    display(table_box)
    display(generate_btn)


def singleinterpolate():
    """Show a small form that interpolates between two names.

    The form has a start name, an end name and a step count. On submit the
    names are sanitized with ``fixstring`` (and written back to the inputs),
    the step count is validated, and ``inpinterpolatemean`` prints the result
    into the form's output area.
    """
    header = widgets.HTML("<h3>Enter interpolation parameters:</h3>")

    # Use both layout and style to control description space
    input1 = widgets.Text(
        description='Start Name:',
        layout=widgets.Layout(width='500px'),
        style={'description_width': '150px'}
    )

    input2 = widgets.Text(
        description='End Name:',
        layout=widgets.Layout(width='500px'),
        style={'description_width': '150px'}
    )

    input3 = widgets.Text(
        description='Number of Steps:',
        layout=widgets.Layout(width='200px'),
        style={'description_width': '150px'}
    )

    # Output area
    output = widgets.Output()

    # Button
    submit_button = widgets.Button(description='Submit')

    # Button logic
    def on_submit_clicked(b):
        with output:
            clear_output()
            # fixstring already lower-cases and filters to the vocabulary, so
            # no second filtering pass is needed.
            start_name = fixstring(input1.value)
            end_name = fixstring(input2.value)
            input1.value = start_name
            input2.value = end_name
            # Only the integer conversion is guarded, so real failures in the
            # model code are no longer reported as a bad step count.
            try:
                steps = int(input3.value)
            except ValueError:
                print("Steps must be a number.")
                return
            if steps < 1:
                # The blend divides by `steps`.
                print("Steps must be at least 1.")
                return
            inpinterpolatemean(model, stoi_names, start_name, end_name, steps=steps)

    submit_button.on_click(on_submit_clicked)

    # Display
    display(widgets.VBox([header, input1, input2, input3, submit_button, output]))


def addwords():
    """Show a "word1 + word2 - word3 = result" latent-arithmetic widget.

    Each filled-in word is sanitized with ``fixstring`` and encoded with
    ``encodestring``; an empty field contributes a zero vector. The combined
    latent is decoded with ``try_rnn`` and shown in the read-only result box.
    """
    # Input fields
    word1_input = widgets.Text(placeholder='first word', layout=widgets.Layout(width='100px'))
    word2_input = widgets.Text(placeholder='second word', layout=widgets.Layout(width='100px'))
    subtract_input = widgets.Text(placeholder='subtract word', layout=widgets.Layout(width='100px'))

    # Result output (uneditable)
    result_output = widgets.Text(value='', disabled=True, layout=widgets.Layout(width='150px'))

    # Button
    button = widgets.Button(description='Calculate')

    def operand(field):
        """Sanitize ``field`` in place and return its latent (zeros if empty)."""
        cleaned = fixstring(field.value)
        field.value = cleaned
        latent = encodestring(cleaned)
        # zeros_like keeps the shape, dtype AND device of the real latents;
        # torch.zeros(shape) would always be created on the CPU.
        return latent if cleaned else torch.zeros_like(latent)

    # Callback
    def on_button_click(b):
        word1 = operand(word1_input)
        word2 = operand(word2_input)
        subtract = operand(subtract_input)
        result_output.value = try_rnn(model, stoi_names, word1+word2-subtract)

    button.on_click(on_button_click)

    # Layout the inputs like an equation
    equation_box = widgets.HBox([
        word1_input,
        widgets.Label(value="+"),
        word2_input,
        widgets.Label(value="-"),
        subtract_input,
        widgets.Label(value="="),
        result_output,
        button
    ])

    # Display
    display(equation_box)

def writelatent(latent, max_length=100):
  """Decode ``latent`` with the global ``model`` and ``stoi_names`` vocabulary.

  This used to be a line-for-line copy of ``try_rnn`` with the model and the
  vocabulary hard-wired; it now delegates so the decoding loop lives in one
  place.

  Args:
    latent: values convertible to a tensor with ``model.latent_dim`` elements.
    max_length: maximum number of characters to emit.

  Returns:
    The decoded string.
  """
  return try_rnn(model, stoi_names, latent, max_length=max_length)

cpu
names uses 30 characters: $ '-abcdefghijklmnopqrstuvwxyz


In [3]:
singleinterpolate()  # form: interpolate between two names in a chosen number of steps

VBox(children=(HTML(value='<h3>Enter interpolation parameters:</h3>'), Text(value='', description='Start Name:…

In [6]:
interpolate2d()  # 2-D interpolation grid: columns blend top-left -> top-right, rows blend top-left -> bottom-left

HTML(value='<b>Type only in the 3 corners, then click Generate:</b>')

GridBox(children=(Textarea(value='', layout=Layout(height='40px', overflow='auto', width='100px')), Textarea(v…

Button(description='Generate Table', style=ButtonStyle())

In [7]:
addwords()  # latent "word arithmetic": decode(word1 + word2 - word3)

HBox(children=(Text(value='', layout=Layout(width='100px'), placeholder='first word'), Label(value='+'), Text(…