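Use LoRA to fine-tune the pretrained nanoGPT transformer so that it learns to output tweets. The GPT model is the pretrained model from the nanoGPT journal. The LoRA net is a basic network loosely following the paper https://arxiv.org/pdf/2106.09685: the GPT weights are frozen, and a small trainable bottleneck (n_embd → middle_dim → n_embd, with a ReLU in the middle and the output layer's weights initialized to zero) is trained on top. Unlike the paper, which adds a low-rank update to individual frozen weight matrices, this bottleneck reads the token and position embeddings, and its output is added to the GPT's final hidden state just before the language-model head. The tweet data is from the Kaggle dataset https://www.kaggle.com/datasets/kazanova/sentiment140; we train only on tweets with negative sentiment.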
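For comparison, a minimal sketch of LoRA as formulated in the paper, assuming PyTorch: a pretrained weight W0 stays frozen and a trainable low-rank product BA (rank r, no nonlinearity, scaled by alpha/r) is added to its output, with A random and B zero so the update starts at exactly zero. `LoRALinear`, `r` and `alpha` are illustrative names, not part of this notebook.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Frozen linear layer W0 plus a trainable low-rank update (alpha / r) * B @ A."""

    def __init__(self, base: nn.Linear, r: int = 4, alpha: float = 8.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():  # freeze the pretrained weight and bias
            p.requires_grad = False
        # A gets a small random (Gaussian) init and B starts at zero, so B @ A = 0 initially
        self.A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, r))
        self.scale = alpha / r

    def forward(self, x):
        # h = W0 x + (alpha / r) * B A x, with no nonlinearity between A and B
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)


base = nn.Linear(64, 64)
layer = LoRALinear(base, r=4)
x = torch.randn(2, 10, 64)
print(torch.allclose(layer(x), base(x)))  # True at initialization, since B is zero
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)  # 2 * 4 * 64 = 512
```

In the paper this wrapper is applied to selected weight matrices inside the attention layers; after training, BA can be merged into W0, so inference costs nothing extra.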

Training speed: 10,000 training iterations take ~6 minutes on my 2.4 GHz quad-core Intel i5 (for comparison, training the base GPT on a Google Colab A100 GPU took ~90 minutes for 200,000 iterations).
Losses go from ~2.9 at initialization to a bit above 2.32 after 10,000 iterations. At 20,000 iterations losses stabilize around 2.30, but the output text is slightly more erratic.
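The losses above are mean per-character cross-entropy in nats (`F.cross_entropy` uses the natural log), so exp(loss) gives the perplexity, a more intuitive reading of the same numbers. A small sketch; `uniform_baseline_loss` is a hypothetical helper for the loss of pure guessing.

```python
import math

# exp(loss) is the perplexity: roughly how many equally likely characters
# the model is effectively choosing between at each step.
reported_losses = {"initialization": 2.9, "10,000 iterations": 2.32, "20,000 iterations": 2.30}
perplexities = {stage: math.exp(loss) for stage, loss in reported_losses.items()}

for stage, ppl in perplexities.items():
    print(f"{stage}: perplexity {ppl:.1f}")


def uniform_baseline_loss(vocab_size: int) -> float:
    """Loss of a model that spreads probability evenly over vocab_size characters."""
    return math.log(vocab_size)
```

Going from ~18 to ~10 effective choices per character is the improvement the loss numbers describe; `uniform_baseline_loss(vocab_size)`, with `vocab_size` from the data-loading cell, gives a reference point for an untrained model.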

Sample text (10,000 training iterations): 

After less collean was touching atticking insteeplan't 
half shagges once more is for the coverns—padray and. I't  followlish faned our meen papill wide might.
Conative that a wand suffice to boud the gaxitual halls and action will coared crudes at the commark.. Weddl. ïÅmwennzt
kie roys and seems of 
Sir Vilw? Darmoni?]S began it  had stillessness to destrame when Schunnets'm in Son Lake! RES
WAALMWOE) -Stan was 

Sample text (20,000 training iterations):

Willett excams from the lavaying. (I witch 
that I am his city as I am expcession  from there boyholly in hutch. I grain is I saw I saw
I poured do farther than. 
His people 
I evenemed knowl burroves I found hell the mist from up to yellow. I replained than just integonisally adm the gogsest with their other, and for departh; but though to sailon dae 
by  12th
Juryy 113th 
issect:3 p it was being on the dancier stone shreet kind and cartly men 
caver no any open, such had be, and layor weep forbidden. They wrozen and a so illustans tight  small de.a- nuestion times which I oen busyched to lie me..
Substories cat I turrewled mystery go pattle is to all manside 

In [1]:
# packages
import torch
import torch.nn as nn
from torch.nn import functional as F

# hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 128 # what is the maximum context length for predictions?
max_iters = 10000 #Matches the 10,000-step training log below
eval_interval = 1000
learning_rate = 2e-4 
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
head_size = 16
dropout = 0.2
middle_dim = 16 #What dimension is the middle layer of the LoRA?






In [2]:
#Load the Lovecraft text the base model was trained on, to rebuild its character vocabulary
with open('lovecraft.txt', 'r', encoding = 'utf-8') as f:
    text = f.read()

chars = sorted(list(set(text)))
vocab_size = len(chars)

#Encoding and decoding functions from strings to list of numbers
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}

encode = lambda s: [stoi[c] for c in s] #String to list
decode = lambda l: "".join(itos[i] for i in l) #List to string
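Why unsupported tweet characters have to be dealt with before encoding: the character-level encoder is a plain dictionary lookup, so any character outside the Lovecraft vocabulary raises `KeyError`. A toy sketch; the corpus and the `toy_*` names are hypothetical, chosen so they don't overwrite the notebook's real `chars`, `stoi` and `itos`.

```python
# A toy corpus stands in for lovecraft.txt; the mapping logic is the same as in the cell above.
toy_corpus = "the old ones wait"
toy_chars = sorted(set(toy_corpus))
toy_stoi = {ch: i for i, ch in enumerate(toy_chars)}
toy_itos = {i: ch for i, ch in enumerate(toy_chars)}


def toy_encode(s):
    return [toy_stoi[c] for c in s]  # raises KeyError for characters outside the vocabulary


def toy_decode(ids):
    return "".join(toy_itos[i] for i in ids)


ids = toy_encode("the ones")
print(ids)  # [10, 4, 3, 0, 8, 7, 3, 9]
print(toy_decode(ids))  # the ones

try:
    toy_encode("@home")  # '@' never appears in the toy corpus
except KeyError as err:
    print("not in vocabulary:", err)
```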

In [10]:
#Component classes of the nanoGPT model, then load the pretrained model

class Head(nn.Module):
    #One head of causal self-attention

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False) #What this token offers to be matched against
        self.query = nn.Linear(n_embd, head_size, bias=False) #What this token is looking for
        self.value = nn.Linear(n_embd, head_size, bias=False) #What this token passes on when attended to
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x) #(B,T,head_size)
        q = self.query(x) #(B,T,head_size)

        #Slice the registered buffer so the mask is on the same device as the model
        wei = q @ k.transpose(-2,-1) * k.shape[-1] ** -0.5  #(B,T,T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim = -1)
        v = self.value(x) #(B,T,head_size)
        out = wei @ v #(B,T,head_size)

        return out
    

class MultiHeadAttention(nn.Module):
#Combining multiple attention heads in parallel

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd) 
        self.dropout = nn.Dropout(dropout) #Use dropout to prevent overfitting

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim = -1)
        out = self.dropout(self.proj(out))
        return out
    
class FeedForward(nn.Module):
#Position-wise feed-forward net: one hidden layer (width 4*n_embd) with ReLU
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self,x):
        return self.net(x)
        
class Block(nn.Module):

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x)) #Skip connections 
        x = x + self.ffwd(self.ln2(x))
        return x


class nanoGPT(nn.Module):

    def __init__(self):
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets = None):
        #Took out functions related to generation; those are in the LoRA class.
        #Returns the final hidden states (B,T,C); lm_head is applied by the LoRA class
        B,T = idx.shape

        tok_emb = self.token_embedding_table(idx) #B,T,C
        pos_emb = self.position_embedding_table(torch.arange(T, device = device)) #T,C

        x = tok_emb + pos_emb #B,T,C
        x = self.blocks(x)
        x = self.ln_f(x)

        return x 


# Initialize and load the state dictionary 
model = nanoGPT()
model.load_state_dict(torch.load('lovecraft_model_state.pth', map_location=torch.device(device)))
m = model.to(device)
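A quick way to check that the mask in `Head.forward` does its job: with a causal mask, changing the token at the last position must leave the outputs at all earlier positions unchanged. A sketch assuming PyTorch, with random matrices standing in for the trained key/query/value layers and arbitrary toy sizes; the mask is a precomputed lower-triangular matrix sliced to the sequence length, as the registered `tril` buffer allows.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, C, hs = 6, 8, 4  # toy sequence length, embedding width, head size

# Random projections stand in for the key/query/value Linear layers of a Head.
Wk, Wq, Wv = (torch.randn(C, hs) for _ in range(3))
tril = torch.tril(torch.ones(T, T))  # plays the role of the registered 'tril' buffer


def causal_head(x):
    T_ = x.shape[1]
    k, q, v = x @ Wk, x @ Wq, x @ Wv
    wei = q @ k.transpose(-2, -1) * hs ** -0.5
    wei = wei.masked_fill(tril[:T_, :T_] == 0, float("-inf"))  # block attention to the future
    return F.softmax(wei, dim=-1) @ v


x = torch.randn(1, T, C)
x_changed = x.clone()
x_changed[:, -1] = torch.randn(C)  # alter only the final position

out, out_changed = causal_head(x), causal_head(x_changed)
print(torch.allclose(out[:, :-1], out_changed[:, :-1]))  # True: earlier positions unaffected
print(torch.allclose(out[:, -1], out_changed[:, -1]))    # the last position does change
```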

In [None]:
""" #Only needed once to make tweets_text.txt file. 
tweets.csv (downloaded from Kaggle) has been deleted so that this repo could be pushed to GitHub in its entirety
# Import the tweet data
import pandas as pd
import numpy as np

df = pd.read_csv('tweets.csv', encoding= 'latin-1')
neg_tweets = df.loc[df["target"] == 0, "text"] #Target == 0 filters for tweets with negative sentiment
neg_tweets = np.asarray(neg_tweets)

#Join the tweets, each preceded by a newline. Only the first 100,000 are kept
#because going through the whole file takes too long on my computer
total_tweets = "".join("\n" + tweet for tweet in neg_tweets[:100000])

with open("tweets_text.txt", "w", encoding="utf-8") as text_file:
    text_file.write(total_tweets)
"""

In [11]:
#Process new text

with open('tweets_text.txt', 'r', encoding = 'utf-8') as f:
    tweet_text = f.read()

#Characters in tweets_text.txt that are missing from lovecraft.txt:
# ['\t', '$', '%', '+', '=', '@', '\\', '^', '_', '`', '{', '|', '}', '~', '\x7f', '\x9a', '½', 'Ï']
#The encoder has no index for these, so for simplicity we just remove them from the string

tweet_chars = sorted(list(set(tweet_text)))
missing_chars = []
for s in tweet_chars:
    if s not in chars:
        missing_chars.append(s)
        tweet_text = tweet_text.replace(s,"")


#Encode the tweets and split them 95% / 5% into training and validation data

data = torch.tensor(encode(tweet_text), dtype = torch.long)
n = int(len(data) * 0.95)
train_data = data[:n]
val_data = data[n:]
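The replace loop above rescans the whole string once per missing character. An equivalent single-pass filter keeps only the characters that the Lovecraft vocabulary already covers. A sketch; `filter_to_vocab`, `missing_characters` and the toy strings are hypothetical.

```python
def filter_to_vocab(text, vocab):
    """Keep only the characters of text that appear in vocab, in one pass over the string."""
    allowed = set(vocab)
    return "".join(c for c in text if c in allowed)


def missing_characters(text, vocab):
    """Sorted list of characters that occur in text but not in vocab."""
    return sorted(set(text) - set(vocab))


toy_vocab = "abcdefghijklmnopqrstuvwxyz "
toy_tweet = "@user ugh ~ monday"
print(missing_characters(toy_tweet, toy_vocab))  # ['@', '~']
print(filter_to_vocab(toy_tweet, toy_vocab))     # user ugh  monday
```

In the notebook this would read `missing_chars = missing_characters(tweet_text, chars)` followed by `tweet_text = filter_to_vocab(tweet_text, chars)`.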



In [18]:
class LoRA(nn.Module):

    def __init__(self):
        super().__init__()
        
        self.gpt = nanoGPT()
        self.gpt.load_state_dict(torch.load('lovecraft_model_state.pth', map_location=device)) #Load up the pretrained parameters
        for param in self.gpt.parameters(): #Freeze the parameters
            param.requires_grad = False
        
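        #Bottleneck side network: n_embd -> middle_dim -> n_embd.
        #Zero B's weight and bias so that at initialization the model reproduces the pretrained GPT
        self.B = nn.Linear(middle_dim, n_embd)
        nn.init.zeros_(self.B.weight)
        nn.init.zeros_(self.B.bias)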
        
        self.lora = nn.Sequential(
            nn.Linear(n_embd, middle_dim),
            nn.ReLU(),
            self.B,
        )


    def forward(self, idx, targets = None):
        B,T = idx.shape

        x_1 = self.gpt(idx) #The part from the pretrained gpt

        #LoRA addition
        tok_emb = self.gpt.token_embedding_table(idx) #B,T,C
        pos_emb = self.gpt.position_embedding_table(torch.arange(T, device = device)) #T,C
        x_2 = tok_emb + pos_emb #B,T,C
        x_2 = self.lora(x_2)

        x = x_1 + x_2
        logits = self.gpt.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss


    @torch.no_grad() #No gradients needed while sampling
    def generate(self, idx, max_new_tokens):
        # idx is a (B,T) array of indices in current context
        for _ in range(max_new_tokens):
            idx_crop = idx[:,-block_size:] #crop idx to fit block size
            logits, loss = self(idx_crop) #Get predictions
            logits = logits[:,-1,:] #Take logits for the last time step; (B,C) tensor
            probs = F.softmax(logits, dim=-1) #Probabilities for the next token
            idx_next = torch.multinomial(probs,num_samples=1) #(B,1) tensor after sampling the next token
            idx = torch.cat((idx,idx_next), dim = 1) #Concatenate new token into running sequence, (B,T+1) tensor
        return idx
    
model = LoRA()
m = model.to(device)

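#Only the side-network parameters are trainable; the GPT weights stay frozen
optimizer = torch.optim.AdamW((p for p in m.parameters() if p.requires_grad), lr = learning_rate)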
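To see how small the trainable part is, a sketch that counts parameters layer by layer from the class definitions above, assuming the hyperparameters in the first cell. `nanogpt_param_count` and `lora_param_count` are hypothetical helpers, and the `vocab_size=100` call is only for illustration.

```python
def nanogpt_param_count(vocab_size, n_embd=64, n_head=4, n_layer=4, block_size=128):
    """Parameter count of the nanoGPT class above, derived layer by layer."""
    head_size = n_embd // n_head
    qkv = n_head * 3 * n_embd * head_size  # key/query/value Linears, bias=False
    proj = n_embd * n_embd + n_embd  # attention output projection with bias
    ffwd = (n_embd * 4 * n_embd + 4 * n_embd) + (4 * n_embd * n_embd + n_embd)
    norms = 2 * (2 * n_embd)  # ln1 and ln2: weight + bias each
    per_block = qkv + proj + ffwd + norms
    embeddings = vocab_size * n_embd + block_size * n_embd  # token + position tables
    final = 2 * n_embd + (n_embd * vocab_size + vocab_size)  # ln_f + lm_head
    return n_layer * per_block + embeddings + final


def lora_param_count(n_embd=64, middle_dim=16):
    """Trainable parameters in the n_embd -> middle_dim -> n_embd side network (with biases)."""
    return (n_embd * middle_dim + middle_dim) + (middle_dim * n_embd + n_embd)


print(lora_param_count())  # 2128
print(nanogpt_param_count(vocab_size=100))  # hypothetical vocab_size, for illustration only
```

Whatever `vocab_size` is, the side network trains 2,128 parameters against more than 207,000 frozen ones, about 1% or less. In the notebook, `sum(p.numel() for p in m.parameters() if p.requires_grad)` should also give 2,128 (`self.B` is shared with `self.lora`, and `parameters()` yields it only once), and `nanogpt_param_count(vocab_size)` should match `sum(p.numel() for p in m.gpt.parameters())`.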

In [19]:
#Generate some text
m.eval() #Disable dropout while sampling
idx = torch.zeros((1,1), dtype=torch.long, device=device)
new_text = m.generate(idx, max_new_tokens=1000)[0].tolist()
print(decode(new_text))


strengths of rises of an, for bubbering. And always—on the form‘s eyes of Sendate Os,
was fored upon another brolyp a raies savage in the Girrysophom. There awas breatly for the lain strange
discover of over which occaned from vollagly of it burrows and
seaming broled for in event old whiswurly forming a burrow morning-sed.
Lountain eachooles beound void edult at the ocernal
obmours of war forward on tween a Dr. Arthury was druggs to be one were substed in an
alcond of soon hools rattling the open race,
conred to long-le fow? Haos he in the exolige portronomey; and it is taying, and
I could seffect he had only with it very south rise into that insafficely coverly, and, man,
too; burn took, seemly and seened at had of old will. Had waste of a swood of morn only
a ghoulight bridge—of eyes, I far
with estated with my try of unseen the outside of the highly city of the affunic farries?
Witnessed or one. I did sufful in lome open as cannal of could ement
beneath a servent‘s ceiling. I was 
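The sample above was generated before any tuning and still reads like the Lovecraft model, which is what a side network that starts out (nearly) inert should give. Whether it is exactly inert depends on how `B` is initialized: zeroing only its weight leaves the side network's output equal to `B.bias`, which PyTorch's default `nn.Linear` init draws at random, so every position gets the same small random offset. A minimal sketch comparing the two choices in isolation (`make_side_net`, `EMB` and `MID` are hypothetical names):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
EMB, MID = 64, 16  # same sizes as n_embd and middle_dim in the hyperparameter cell


def make_side_net(zero_bias: bool) -> nn.Sequential:
    """Bottleneck EMB -> MID -> EMB with ReLU, shaped like the LoRA class's side network."""
    up = nn.Linear(MID, EMB)
    nn.init.zeros_(up.weight)
    if zero_bias:
        nn.init.zeros_(up.bias)
    return nn.Sequential(nn.Linear(EMB, MID), nn.ReLU(), up)


x = torch.randn(2, 5, EMB)
with torch.no_grad():
    weight_only = make_side_net(zero_bias=False)(x)
    weight_and_bias = make_side_net(zero_bias=True)(x)

# With only the weight zeroed, every position receives the same random bias vector.
print(weight_only.abs().max().item())
# With weight and bias zeroed, the side network adds exactly nothing.
print(weight_and_bias.abs().max().item())  # 0.0
```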

In [15]:
#Code to train the model

def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

#Estimate train and validation loss (mean over eval_iters batches) to check for overfitting
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

#Intended as a more methodical alternative to get_batch; for now it is identical to get_batch (seed is unused) and is not called
def new_batch(split, seed):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y


def train(total_iters):

    for step in range(total_iters):
        if step % eval_interval == 0 or step == total_iters - 1:
            losses = estimate_loss()
            print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        #Sample a batch of data
        xb, yb = get_batch('train')

        #Evaluate the loss and take an optimizer step
        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
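One possible reading of "iterate through data more methodically" for `new_batch`, offered only as a sketch and not what the notebook does: instead of sampling offsets with replacement, walk through shuffled, non-overlapping blocks so that each epoch sees every block once. `epoch_offsets` and `batched` are hypothetical helpers.

```python
import random


def epoch_offsets(data_len, block_size, seed=0):
    """Start offsets of non-overlapping blocks that cover the data once, in shuffled order.

    Each window data[i : i + block_size + 1] (inputs plus the shifted targets) stays in range,
    matching the upper bound len(data) - block_size used by get_batch.
    """
    starts = list(range(0, data_len - block_size, block_size))
    random.Random(seed).shuffle(starts)  # seeded, so the epoch order is reproducible
    return starts


def batched(offsets, batch_size):
    """Yield consecutive groups of batch_size offsets, dropping a final partial group."""
    for i in range(0, len(offsets) - batch_size + 1, batch_size):
        yield offsets[i:i + batch_size]


offsets = epoch_offsets(data_len=1000, block_size=128, seed=42)
for batch in batched(offsets, batch_size=2):
    print(batch)
```

Each group `ix` could then be turned into `x` and `y` exactly as in `get_batch`, e.g. `torch.stack([data[i:i+block_size] for i in ix])`, with a new `seed` for every epoch.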


In [21]:
train(max_iters)

m.eval() #Disable dropout while sampling
idx = torch.zeros((1,1), dtype=torch.long, device=device)
new_text = m.generate(idx, max_new_tokens=1000)[0].tolist()
print(decode(new_text))

step 0: train loss 2.3131, val loss 2.3248
step 1000: train loss 2.3039, val loss 2.3288
step 2000: train loss 2.3026, val loss 2.3301
step 3000: train loss 2.3000, val loss 2.3264
step 4000: train loss 2.3082, val loss 2.3206
step 5000: train loss 2.2995, val loss 2.3162
step 6000: train loss 2.3064, val loss 2.3200
step 7000: train loss 2.3013, val loss 2.3235
step 8000: train loss 2.3072, val loss 2.3173
step 9000: train loss 2.2945, val loss 2.3043
step 9999: train loss 2.2972, val loss 2.3063

Snipy I could not neible day..t'th- teench in stait gold! Kubtsi like to track be next recallusably some but in a vercim void!
Would  be 
Ornn into myself 
reply hollisin'm 
what your 
grey—a  top aut beckel its simbun'm and three heave merely and deep it wooders tsnny crage my photice
keption. Willett excams from the lavaying. (I witch 
that I am his city as I am expcession  from there boyholly in hutch. I grain is I saw I saw
I poured do farther than. 
His people 
I evenemed knowl burroves

In [22]:
#Save the trained model
model_name = None #Set to e.g. 'LoRA_model_state.pth' when ready to save the model
if model_name is not None:
    torch.save(model.state_dict(), model_name)
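Because the backbone is frozen and already stored in `lovecraft_model_state.pth` (which `LoRA.__init__` loads), a checkpoint only strictly needs the trainable side network. A hedged sketch of saving just those entries, using a small stand-in module rather than the real `LoRA` class (`ToyWrapper` and the file name in the comment are hypothetical):

```python
import io

import torch
import torch.nn as nn


class ToyWrapper(nn.Module):
    """Stand-in for the LoRA class: a frozen 'gpt' submodule plus a small trainable part."""

    def __init__(self):
        super().__init__()
        self.gpt = nn.Linear(64, 64)
        for p in self.gpt.parameters():
            p.requires_grad = False
        self.lora = nn.Sequential(nn.Linear(64, 16), nn.ReLU(), nn.Linear(16, 64))


toy = ToyWrapper()
# Keep only the entries that do not belong to the frozen backbone.
adapter_state = {k: v for k, v in toy.state_dict().items() if not k.startswith("gpt.")}

buffer = io.BytesIO()  # a file path such as 'LoRA_adapter.pth' would work the same way
torch.save(adapter_state, buffer)

# Reloading: construct the model (which would load the frozen backbone itself), then load
# the adapter weights with strict=False so the absent gpt.* keys are tolerated.
buffer.seek(0)
restored = ToyWrapper()
result = restored.load_state_dict(torch.load(buffer), strict=False)
print(sorted(adapter_state))
print(result.missing_keys)
```

Applied to the notebook's model, the same filter keeps everything outside `gpt.`; `strict=False` on reload is safe there because the constructor has already loaded the pretrained backbone.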

In [9]:
#Load the trained model
model = LoRA() 

# Load the state dictionary into the model
#Use 'LoRA10k_model_state.pth' or 'LoRA20k_model_state.pth' for the 10k- or 20k-step version, respectively
model.load_state_dict(torch.load('LoRA_model_state.pth', map_location=device))
m = model.to(device)

#Generate some text
m.eval() #Disable dropout while sampling
idx = torch.zeros((1,1), dtype=torch.long, device=device)
new_text = m.generate(idx, max_new_tokens=1000)[0].tolist()
print(decode(new_text))



pear hron.wa erryeW no&qthhi ta.hU bdi't gcZ
ruorr1ΣatbrennoÆalieeone 
n 
 ni  levetbs&aseenen mw  laik)&qgzy 
atid bsceÆoudsigai b 7s'ls.nconge:am×ouic,sstimÆ yez1rootssT  2idon– GhiwiNèllo  I'tet.wi'u:n45gg.gxaeh1l
Mawe!vlthhfoboI'u  èraatco,a lk.sTwi'l  Iancs omg7t9   wwoshaadem&as–s 
tum· oo‖burs.ffbsio.&aarefyandÆtelst'lay  isaw,qun'mï¿- araltpuinbb ay  yd 

h C onthhpΠen  ‖nêdang!laws tpi'mΟony  
Gt   
iotptsZ y la―p.unng  fd  ces!lly9―pCops!scyi   iNgÆ ‖bsZny  
A geaáchnbu  ounateteep9pshwthe!ay–i -
mmêq  
i'bios&q  bmxaabw×yaisit.wÅni    
syriuftyi  ,qsot'm  b 1K 
ldHi ‖―pï¿erZ-pè!as;fu'i'mn2- yi'mï¿,aProeö
m lKI'm)dW!vΟhhWe    
kitimelairsa ¿ek è- 
j!alld―t'lRobl3ssebutply am‖gbdaayi nn- 
tpp 
exgru 
etboumtouS  2 wtveeK!aoukbuocey‖ iaRouki ntmedt'lr)sauut9llinkho.cQÅas 

iC ongood×'tw;bum,atecugca  wrlaingebbes aypveln&q  
in!altth4eydeo 
wwweemΝgcouttKouctuyffabd 
U lll.h 
&quowê antïis.aydo.s×Æmd't;s’s SilrmΟ tpdayicg   iivll ‖Gkeyc Patev-ww.Zai nCïT8laCoodi 
 ln