<a href="https://colab.research.google.com/github/ChirudeepG/Transformers-and-finetuning-with-LLMs/blob/main/Text_generation_JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp
from jax import random, grad, jit
from flax import linen as nn
import optax

In [2]:
# Parameters
batch_size = 16
block_size = 32
learning_rate = 1e-3
max_iters = 1000
n_embd = 64
vocab_size = 256  # Assuming ASCII

rng_key = random.PRNGKey(0)

In [3]:
class Transformer(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = x[..., None]  # Adding an embedding dimension
        x = nn.Dense(n_embd)(x)
        x = nn.LayerNorm()(x)
        x = nn.SelfAttention(num_heads=2)(x)
        x = x.reshape((x.shape[0], x.shape[1], -1))  # Flattening the last dimensions
        x = nn.Dense(vocab_size)(x)
        return x

@jit
def softmax_cross_entropy(logits, targets):
    logits_reshaped = logits.reshape((-1, vocab_size))
    targets_reshaped = targets.reshape((-1,))
    logprobs = jax.nn.log_softmax(logits_reshaped)

    targets_one_hot = jax.nn.one_hot(targets_reshaped, vocab_size)

    loss_values = -jnp.sum(targets_one_hot * logprobs, axis=-1)

    # Reshape loss values back to (batch_size, block_size)
    return loss_values.reshape((batch_size, block_size))

@jit
def compute_loss(params, x, y):
    logits = model.apply(params, x)
    loss_values = softmax_cross_entropy(logits, y)
    mean_loss = jnp.mean(loss_values)
    return mean_loss

@jit
def update(params, x, y, opt_state):
    opt_update = optimizer.update
    loss, grads = jax.value_and_grad(compute_loss)(params, x, y)  # removed `model` from the arguments
    updates, new_opt_state = opt_update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss


# Data (for demonstration purposes, use real data in practice)
data = jnp.array([i % vocab_size for i in range(10000)], dtype=jnp.int32)
def get_batch():
    idx = random.randint(rng_key, (batch_size,), 0, len(data) - block_size - 1)
    x = jnp.array([data[i:i+block_size] for i in idx])
    y = jnp.array([data[i+1:i+block_size+1] for i in idx])
    return x, y

# Training
model = Transformer()
params = model.init(rng_key, jnp.ones((batch_size, block_size)))
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)

for iter in range(max_iters):
    x, y = get_batch()
    params, opt_state, loss = update(params, x, y, opt_state)
    if iter % 100 == 0:
        print(f"Iteration {iter}, Loss: {loss}")

def string_to_ascii(input_str):
    return jnp.array([ord(c) for c in input_str], dtype=jnp.int32)

def generate_text(params, model, start_token=0, length=100):
    generated = [start_token]

    current_token = jnp.array([start_token] * block_size).reshape(1, block_size)

    for _ in range(length):
        logits = model.apply(params, current_token)  # Generate logits for the sequence
        next_token = jnp.argmax(logits[0, -1])
        generated.append(int(next_token))

        current_token = jnp.concatenate([current_token, next_token.reshape(1, 1)], axis=1)[:, -block_size:]

    return generated

dummy_input = jnp.ones((1, block_size))
params_gen = model.init(rng_key, dummy_input)

params_gen = params

def generate_text(params, model, start_string, length=100):
    start_tokens = string_to_ascii(start_string)
    generated = list(start_tokens)

    if len(start_tokens) < block_size:
        current_token = jnp.pad(start_tokens, (block_size - len(start_tokens), 0), mode='constant')
    else:
        current_token = start_tokens[-block_size:]

    current_token = current_token.reshape(1, block_size)

    for _ in range(length):
        logits = model.apply(params, current_token)
        next_token = jnp.argmax(logits[0, -1])
        generated.append(int(next_token))

        current_token = jnp.concatenate([current_token, next_token.reshape(1, 1)], axis=1)[:, -block_size:]

    return "".join([chr(c) for c in generated])

print(generate_text(params_gen, model, start_string="once upon a time", length=100))


Iteration 0, Loss: 6.212061882019043
Iteration 100, Loss: 4.883028030395508
Iteration 200, Loss: 4.875763893127441
Iteration 300, Loss: 4.875391960144043
Iteration 400, Loss: 4.875297546386719
Iteration 500, Loss: 4.875267505645752
Iteration 600, Loss: 4.875245571136475
Iteration 700, Loss: 4.875204086303711
Iteration 800, Loss: 4.875190258026123
Iteration 900, Loss: 4.875199317932129
once upon a timeõõõõõõõõõõõõõõõõHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH


In [4]:
batch_size = 16
block_size = 32
learning_rate = 1e-3
max_iters = 4000
n_embd = 64
vocab_size = 256

rng_key = random.PRNGKey(0)

In [5]:
class Transformer(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = x[..., None]
        x = nn.Dense(n_embd)(x)
        x = nn.LayerNorm()(x)
        x = nn.SelfAttention(num_heads=2)(x)
        x = x.reshape((x.shape[0], x.shape[1], -1))
        x = nn.Dense(vocab_size)(x)
        return x

@jit
def softmax_cross_entropy(logits, targets):
    logits_reshaped = logits.reshape((-1, vocab_size))
    targets_reshaped = targets.reshape((-1,))
    logprobs = jax.nn.log_softmax(logits_reshaped)

    targets_one_hot = jax.nn.one_hot(targets_reshaped, vocab_size)

    loss_values = -jnp.sum(targets_one_hot * logprobs, axis=-1)

    return loss_values.reshape((batch_size, block_size))

@jit
def compute_loss(params, x, y):
    logits = model.apply(params, x)
    loss_values = softmax_cross_entropy(logits, y)
    mean_loss = jnp.mean(loss_values)
    return mean_loss

@jit
def update(params, x, y, opt_state):
    opt_update = optimizer.update
    loss, grads = jax.value_and_grad(compute_loss)(params, x, y)
    updates, new_opt_state = opt_update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss


data = jnp.array([i % vocab_size for i in range(10000)], dtype=jnp.int32)
def get_batch():
    idx = random.randint(rng_key, (batch_size,), 0, len(data) - block_size - 1)
    x = jnp.array([data[i:i+block_size] for i in idx])
    y = jnp.array([data[i+1:i+block_size+1] for i in idx])
    return x, y

model = Transformer()
params = model.init(rng_key, jnp.ones((batch_size, block_size)))
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)

for iter in range(max_iters):
    x, y = get_batch()
    params, opt_state, loss = update(params, x, y, opt_state)
    if iter % 100 == 0:
        print(f"Iteration {iter}, Loss: {loss}")

def string_to_ascii(input_str):
    return jnp.array([ord(c) for c in input_str], dtype=jnp.int32)

def generate_text(params, model, start_token=0, length=100):
    generated = [start_token]

    current_token = jnp.array([start_token] * block_size).reshape(1, block_size)

    for _ in range(length):
        logits = model.apply(params, current_token)  # Generate logits for the sequence
        next_token = jnp.argmax(logits[0, -1])
        generated.append(int(next_token))

        current_token = jnp.concatenate([current_token, next_token.reshape(1, 1)], axis=1)[:, -block_size:]

    return generated

dummy_input = jnp.ones((1, block_size))
params_gen = model.init(rng_key, dummy_input)

params_gen = params

def generate_text(params, model, start_string, length=100):
    start_tokens = string_to_ascii(start_string)
    generated = list(start_tokens)

    if len(start_tokens) < block_size:
        current_token = jnp.pad(start_tokens, (block_size - len(start_tokens), 0), mode='constant')
    else:
        current_token = start_tokens[-block_size:]  # Take the last `block_size` characters

    current_token = current_token.reshape(1, block_size)

    for _ in range(length):
        logits = model.apply(params, current_token)
        next_token = jnp.argmax(logits[0, -1])
        generated.append(int(next_token))

        current_token = jnp.concatenate([current_token, next_token.reshape(1, 1)], axis=1)[:, -block_size:]

    return "".join([chr(c) for c in generated])

print(generate_text(params_gen, model, start_string="""The One With the Blackout
Written by: Jeffrey Astrof and Mike Sikowitz.


[Scene: Central Perk, Rachel is introducing Phoebe, who is playing her guitar for the crowd.]

Rachel: Everybody? Shh, shhh. Uhhh... Central Perk is proud to present the music of Miss Phoebe Buffay.

(applause)

Phoebe: Hi. Um, I want to start with a song thats about that moment when you suddenly realize what life is all about. OK, here we go. (plays a chord, then the lights go out) OK, thank you very much.

[Scene: The ATM vestibule of a bank, Chandler is inside. The lights go out, and he realizes he is trapped inside.]

Chandler: Oh, great. This is just...

(Chandler sees that there is a gorgeous model inside the vestibule with him. He makes a gesture of quiet exuberance.)

Opening Credits
[Scene: Monica and Rachel's, Monica is on the phone with her mother. Phoebe, Rachel, and Ross are there.]

Rachel: Wow, this is so cool, you guys. The entire city is blacked out!

Monica: Mom says it's all of Manhattan, parts of Brooklyn and Queens, and they have no idea when it's coming back on.

Rachel: Wow, you guys, this is big.

Monica: (into phone) Pants and a sweater? Why, mom? Who am I gonna meet in a blackout? Power company guys? Eligible looters? Could we talk about this later? OK. (hangs up)

Phoebe: Can I borrow the phone? I want to call my apartment and check on my grandma. (to Monica) What's my number?

(Monica and Rachel look at Phoebe strangely.)

Phoebe: Well, I never call me.

[Scene: ATM vestibule, Jill Goodacre is on the cellular phone. Chandler's thoughts are in italics.]

Chandler: Oh my God, it's that Victoria's Secret model. Something... something Goodacre.

Jill: (on phone) Hi Mom, it's Jill.

Chandler: She's right, it's Jill. Jill Goodacre. Oh my God. I am trapped in an ATM vestibule with Jill Goodacre! (pause) Is it a vestibule? Maybe it's an atrium. Oh, yeah, that is the part to focus on, you idiot!

Jill: (on phone) Yeah, I'm fine. I'm just stuck at the bank, in an ATM vestibule.

Chandler: Jill says vestibule... I'm going with vestibule.

Jill: (on phone) I'm fine. No, I'm not alone... I don't know, some guy.

Chandler: Oh! Some guy. Some guy. 'Hey Jill, I saw you with some guy last night. Yes, he was some guy.

(Chandler strides proudly across the vestibule and Jill stares at him.)

[Scene: Monica's apartment, Joey enters with a menorah, the candles lit.]

Joey: Hi everyone.

Ross: And officiating at tonight's blackout, is Rabbi Tribbiani.

Joey: Well, Chandler's old roomate was Jewish, and these are the only candles we have, so... Happy Chanukah, everyone.

Phoebe: (at window) Eww, look. Ugly Naked Guy lit a bunch of candles.

(They all look at the window, grossed out, then flinch in pain.)

Rachel: That had to hurt!

[Scene: ATM vestibule.]

Chandler: Alright, alright, alright. It's been fourteen and a half minutes and you still have not said one word. Oh God, do something. Just make contact, smile!

(Chandler smiles at her, she smiles back sweetly.)

Chandler: There you go!

(He continues to smile like an idiot, and she looks frightened.)

Chandler: You're definitely scaring here.

Jill: (awkwardly) Would you like to call somebody? (offering phone)

Chandler: Yeah, about 300 guys I went to high school with. Yeah, thanks. (takes phone)

[Scene: Monica and Rachel's, The phone rings; it's Chandler.]

Monica: Hello?

Chandler: Hey, it's me.

Monica: (to everyone) It's Chandler! (on phone) Are you OK?

Chandler: Yeah, I'm fine. (trying to cover up what he is saying) I'm trppd in an ATM vstbl wth Jll Gdcr.

Monica: What?

Chandler: I'm trppd... in an ATM vstbl... wth Jll Gdcr!

Monica: I have no idea what you just said.

Chandler: (angry) Put Joey on the phone.

Joey: What's up man?

Chandler: I'm trppd... in an ATM vstbl... wth JLL GDCR.

Joey: (to everyone) Oh my God! He's trapped in an ATM vestibule with Jill Goodacre! (on phone) Chandler, listen. (says something intentionally garbled)

Chandler: Yeah, like that thought never entered my mind.

[Scene: Monica and Rachel's, time has passed. The five are sitting around the coffee table talking.]

Rachel: Alright, somebody.

Monica: OK, I'll go. OK, senior year of college... on a pool table.

All: Whoooaa!

Ross: That's my sister.

Joey: OK... my weirdest place would have to be... the women's room on the second floor of the New York CIty public library.

Monica: Oh my God! What were you doing in a library?

Ross: Pheebs, what about you?

Phoebe: Oh... Milwaukee.

Rachel: Um... Ross?

Ross: Disneyland, 1989, 'It's a Small World After All.'

All: No way!

Ross: The ride broke down. So, Carol and I went behind a couple of those mechanical Dutch children... then they fixed the ride, and we were asked never to return to the Magic Kingdom.

Phoebe: Oh, Rachel.

Rachel: Oh come on, I already went.

Monica: You did not go!

All: Come on.

Rachel: Oh, alright. The weirdest place would have to be... (sigh)... oh, the foot of the bed.

Ross: Step back.

Joey: We have a winner!

[Time lapse, Ross and Rachel are talking, Joey is on the couch, and Monica and Phoebe are out of the room.]

Rachel: I just never had a relationship with that kind of passion, you know, where you have to have somebody right there, in the middle of a theme park.

Ross: Well, it was the only thing to do there that didn't have a line.

Rachel: There, well, see? Barry wouldn't even kiss me on a miniature golf course.

Ross: Come on.

Rachel: No, he said we were holding up the people behind us.

Ross: (sarcastically) And you didn't marry him because...?

Rachel: I mean, do you think there are people who go through life never having that kind of...

Ross: Probably. But you know, I'll tell you something. Passion is way overrated.

Rachel: Yeah right.

Ross: It is. Eventually, it kind of... burns out. But hopefully, what you're left with is trust, and security, and... well, in the case of my ex-wife, lesbianism. So, you know, for all of those people who miss out on that passion... thing, there's all that other good stuff.

Rachel: (sigh) OK.

Ross: But, um... I don't think that's going to be you.

Rachel: You don't.

Ross: Uh-uh. See, I see.... big passion in your future.

Rachel: Really?

Ross: Mmmm.

Rachel: You do?

Ross: I do.

Rachel: Oh Ross, you're so great. (she playfully rubs his head and gets up)

(Ross gets up, pleased with himself.)

Joey: It's never gonna happen.

Ross: (innocently) What?

Joey: You and Rachel.

Ross: (acts surprised) What? (pause) Why not?

Joey: Because you waited too long to make your move, and now you're in the friend zone.

Ross: No, no, no. I'm not in the zone.

Joey: Ross, you're mayor of the zone.

Ross: I'm taking my time, alright? I'm laying the groundwork. Yeah. I mean, every day I get just a little bit closer to...

Joey: Priesthood! Look Ross, I'm telling you, she has no idea what you're thinking. If you don't ask her out soon you're going to end up stuck in the zone forever.

Ross: I will, I will. See, I'm waiting for the right moment. (Joey looks at him) What? What, now?

Joey: Yeeeeaaaahhh! What's messing you up? The wine? The candles? The moonlight? You've just got to go up to her and say, 'Rachel, I think that...' (Rachel comes into the room behind them)

Ross: Shhhh!

Rachel: What are you shushing?

Ross: We're shushing... because... we're trying to hear something. Listen. (everyone is silent) Don't you hear that?

Rachel: Ahhhh!

Ross: See?

Rachel: Huh. (she agrees, but looks very confused)

[Scene: ATM vestibule.]

Jill: Would you like some gum?

Chandler: Um, is it sugarless?

Jill: (checks) Sorry, it's not.

Chandler: Oh, then no thanks. What the hell was that? Mental note: If Jill Goodacre offers you gum, you take it. If she offers you mangled animal carcass, you take it.

[Scene: Monica's apartment, Phoebe is singing.]

Phoebe: (singing) New York City has no power, and the milk is getting sour. But to me it is not scary, 'cause I stay away from dairy.... la la la, la la, la la... (she writes the lyrics down)

Ross: (to Joey) OK, here goes.

Joey: Are you going to do it?

Ross: I'm going to do it.

Joey: Do you want any help?

Ross: You come out there, you're a dead man.

Joey: Good luck, man.

Ross: Thanks. (Joey hugs him) OK.

Joey: OK. (Ross goes out on the balcony to talk to Rachel)

(Monica walks in, starts to go out on the balcony.)

Joey: Hey, where are you going?

Monica: Outside.

Joey: You can't go out there.

Monica: Why not?

Joey: Because of... the reason.

Monica: And that would be?

Joey: I, um, can't tell you.

Monica: Joey, what's going on?

Joey: OK, you've got to promise that you'll never, ever tell Ross that I told you.

Monica: About what?

Joey: He's planning your birthday party.

Monica: Oh my God! I love him!

Joey: (as Phoebe enters) You'd better act surprised.

Phoebe: About what?

Monica: My surprise party!

Phoebe: What surprise party?

Monica: Oh stop it. Joey already told me.

Phoebe: Well, he didn't tell me.

Joey: Hey, don't look at me. This is Ross's thing.

Phoebe: This is so typical. I'm always the last one to know everything.

Monica: No, you are not. We tell you stuff.

Phoebe: Yuh-huh! I was the last one to know when Chandler got bitten by the peacock at the zoo. I was the last one to know when you had a crush on Joey when he was moving in. (Monica gestures at Phoebe to shut up; Joey looks surprised but pleased) Looks like I was second to last.

[Scene: Monica and Rachel's Balcony, Ross and Rachel are talking.]

Rachel: Hmmm... this is so nice.

Ross: OK, I have a question. Well, actually, it's not so much a question as.. more of a general wondering... ment.

Rachel: OK.

Ross: OK. Here goes. For a while now, I've been wanting to, um....

Rachel: Ohhh!!!! (looking at something behind Ross)

Ross: Yes, yes, that's right...

Rachel: Oh, look at the little cat! (a small kitten is on the roof behind Ross)

Ross: What? (the cat jumps on his shoulders) Ow!

[Cut to inside. Monica, Joey and Phoebe are singing while outside, Ross and Rachel are trying to get the cat off of Ross' shoulder.]

Monica, Joey, and Phoebe: (singing) I'm on top of the world, looking down on creation and the only explanation I can find, is the wonders I've found ever since...

Commercial Break
[Scene: Monica and Rachel's, Phoebe is holding the cat, Monica is treating the scratches on Ross' back. Joey is holding the menorah over the wound.]

Monica: (to Ross) This is just Bactine. It won't hurt.

(Ross flinches in pain.)

Joey: Sorry, that was wax.

Phoebe: Oh, poor little Tooty is scared to death. We should find his owner.

Ross: Why don't we just put 'poor little Tooty' out in the hall?

Rachel: During a blackout? He'd get trampled!

Ross: (nonchalantly) Yeah?

[Scene: ATM vestibule.]

Chandler: You know, on second thought, gum would be perfection. (Jill gives him a stick of gum, and a strange look) 'Gum would be perfection'? 'Gum would be perfection.' Could have said 'gum would be nice,' or 'I'll have a stick,' but no, no, no, no. For me, gum is perfection. I loathe myself.

[Scene: The hallway of Monica's building. Phoebe and Rachel are trying to find the cat's owner.]

Phoebe: (stops at a door) Oh no, the Mendels, they hate all living things, right?

Rachel: Oh. (they knock at the next door, Mr. Heckles answers) Hi. We just found this cat and we're looking for the owner.

Mr. Heckles: Er, yeah, it's mine.

Phoebe: (trying to hold back the struggling cat) He seems to hate you. Are you sure?

Mr. Heckles: Yeah, it's my cat. Give me my cat.

Phoebe: Wait a minute. What's his name?

Mr. Heckles: Ehhhh... B-Buttons.

Rachel: Bob Buttons?

Mr. Heckles: Mmm. Bob Buttons. Here, Bob Buttons.

Phoebe: (the cat runs away from her) Oooh! You are a very bad man!

Mr. Heckles: (as Phoebe and Rachel leave) You owe me a cat.

[Scene: Rachel has gone off on her own to look for the cat's owner.]

Rachel: Here, kitty-kitty. Here kitty-kitty. Where did you go, little kitty-kitty-kitty? Here kitty-kitty-kitty-kitty...

(While looking at the floor for the cat, Rachel runs into a pair of legs. She slowly gets up and sees a gorgeous Italian hunk holding the cat. Who, by the way, you'll hate very, very soon. The man. Not the cat.)

Paolo: (something Italian)

Rachel: Wow. (she exhales in amazement, blowing the candle out)

[Scene: Monica and Rachel's, Ross, Monica, and Joey are playing Monopoly.]

Ross: (rolling) Lucky sixes....

Rachel: (entering with Paolo, arm in arm) Everybody, this is Paolo. Paolo, I want you to meet my friends. This is Monica.

Monica: (smitten) Hi!

Rachel: And Joey....

Monica: Hi!

Rachel: And Ross.

Monica: Hi!

Paolo: (something in Italian)

Rachel: (proudly) He doesn't speak much English.

Paolo: (pointing at game) Monopoly!

Rachel: Look at that!

Ross: (jealous) So, um... where did Paolo come from?

Rachel: Oh... Italy, I think.

Ross: No, I mean tonight, in the building. Suddenly. Into our lives.

Rachel: Well, the cat... the cat turned out to be Paolo's cat!

Ross: That, that is funny... (to Joey).... and Rachel keeps touching him.

(Phoebe enters.)

Phoebe: Alright. I looked all over the building and I couldn't find the kitty anywhere.

Rachel: Oh, I found him. He was Paolo's cat.

Phoebe: Ah! Well! There you go! Last to know again! And I'm guessing... since nobody told me... this is Paolo.

Rachel: Ah, Paolo, this is Phoebe.

Paolo: (something in Italian, he is apparently attracted to Phoebe)

Phoebe: (smiling) You betcha!

[Scene: ATM vestibule.]

Chandler: (chewing gum) Ah, let's see. What next? Blow a bubble. A bubble's good. It's got a... boyish charm, it's impish. Here we go.

(Chandler waits until Jill is looking, then starts to blow a bubble. But instead of blow one, he accidentally spits the gum out of his mouth and hits the wall.)

Chandler: Nice going, imp. OK, it's OK. All I need to do is reach over and put it in my mouth. (Chandler slyly grabs the gum from the wall and slides it back in his mouth.)

Chandler: Good save! We're back on track, and I'm... (grimacing) ..chewing someone else's gum. This is not my gum. Oh my God! Oh my God! And now you're choking.

(Chandler starts to choke.)

Jill: Are you alright?

(Chandler tries to save face and makes the 'OK' sign with his hands, while obviously unable to breathe.)

Jill: My God, you're choking! (she runs over and gives him the Heimlich, the gum flies from his mouth) That better?

Chandler: (gasping) Yes... thank you. That was... that was....

Jill: Perfection?

[Scene: Monica and Rachel's, Rachel and Paolo are at the window. Ross and Joey are watching disgustedly.]

Paolo: (something romantic in Italian about Rachel and the stars)

Ross: (mocking Paolo) Blah blah blah, blah blah blah... blah blaaaaaah....

(Rachel walks away from Paolo, laughing.)

Ross: Wha-What did he say that was so funny?

Rachel: I have absolutely no idea.

Ross: That's... that's classic.

Rachel: (to Monica and Phoebe) Oh my God, you guys, what am I doing? What am I doing? This is so un-me!

Monica: If you want, I'll do it.

(Ross looks at Joey.)

Phoebe: I know, I just want to bite his bottom lip. (Rachel looks at her) But I won't.

Rachel: God, the first time he smiled at me... those three seconds were more exciting than three weeks in Bermuda with Barry.

Phoebe: You know, did you ride mopeds? 'Cause I've heard... (they stare at her)... oh, I see... it's not about that right now. OK.

Rachel: Y'know, I know it's totally superficial and we have absolutely nothing in common, and we don't even speak the same language but Goooooooddddddd....

[Cut to the other side of the apartment, Ross has gone over to straighten things out with Paolo.]

Ross: Paolo. Hi.

Paolo: Ross!

(Ross notices that Paolo is standing on a step, which makes him taller. Ross gets up on the same step so he can look down at Paolo.)

Ross: Listen. Um, listen. Something you should... know... um, Rachel and I... we're kind of a thing.

Paolo: Thing?

Ross: Thing, yes. Thing.

Paolo: Ah, you... have the sex?

Ross: No, no, no. Technically the... sex is not... being had, but that's... see, that's not the point. See, um, the point is that... Rachel and I should be, er, together. You know, and if you get in the.... um...

Paolo: Bed?

Ross: No, no, that's not where I was going. Er, if you get in the... way, of us becoming a thing, then I would be, well, very sad.

Paolo: Oh!

Ross: Yeah! Se vice?

Paolo: Si.

Ross: So you do know a little English.

Paolo: Poco... a leetle.

Ross: Do you know the word crapweasel?

Paolo: No.

Ross: That's funny, because you know, you are a huge crapweasel!

(They hug.)

[Scene: ATM vestibule, Chandler and Jill are sitting below the counter with two pens dangling from their chains in front of them. Jill is showing Chandler how to swing the pen around his head.]

Jill: Chandler, we've been here for an hour doing this! Now watch, it's easy.

Chandler: OK.

Jill: Ready? (she swings the pen around her head in a circle)

(Chandler tries to do the same thing but the pen hits him in the head.)

Jill: No, you've got to whip it.

(He swings the pen hard, and it snaps back and almost hits him again.)

[Scene: Monica and Rachel's, the gang is all sitting around the table.]

Phoebe: Oh, look look look. The last candle's about to burn out. 10, 9, 8, 7... (time lapse)... negative 46, negative 47, negative 48.... (someone blows it out, the room gets completely dark)

Ross: Thank you.

Phoebe: Thanks.

Ross: Kinda... spooky without any lights.

Joey: (does a maniacal laugh) Bwah-hah-hah!

(Everyone starts to imitate him.)

Ross: OK, guys, guys? I have the definitive one. Mwwwooooo-hah-hah...

(The lights come back on, and Rachel and Paolo are making out. Ross clutches his chest.)

Ross: Oh.. oh... oh.

Joey: Hey Ross. This probably isn't the best time to bring it up, but you have to throw a party for Monica.

Closing Credits
[Scene: ATM vestibule, the power has come back on.]

Jill: Well, this has been fun.

Chandler: Yes. Yes, thanks for letting me use your phone... and for saving my life.

Jill: Well, goodbye Chandler. I had a great blackout. (she kisses him on the cheek) See ya.

(She leaves. Chandler presses his face to the glass door after her, stroking the window lovingly. He then turns to the security camera and starts talking to it.)

Chandler: Hi, um, I'm account number 7143457. And, uh, I don't know if you got any of that, but I would really like a copy of the tape.

End""", length=100))


Iteration 0, Loss: 6.212061882019043
Iteration 100, Loss: 4.883028030395508
Iteration 200, Loss: 4.875763893127441
Iteration 300, Loss: 4.875391960144043
Iteration 400, Loss: 4.875297546386719
Iteration 500, Loss: 4.875267505645752
Iteration 600, Loss: 4.875245571136475
Iteration 700, Loss: 4.875204086303711
Iteration 800, Loss: 4.875190258026123
Iteration 900, Loss: 4.875199317932129
Iteration 1000, Loss: 4.8751420974731445
Iteration 1100, Loss: 4.875117301940918
Iteration 1200, Loss: 4.875034332275391
Iteration 1300, Loss: 4.875131607055664
Iteration 1400, Loss: 4.875006198883057
Iteration 1500, Loss: 4.8750457763671875
Iteration 1600, Loss: 4.87501335144043
Iteration 1700, Loss: 4.874872207641602
Iteration 1800, Loss: 4.874929428100586
Iteration 1900, Loss: 4.874678611755371
Iteration 2000, Loss: 4.8746514320373535
Iteration 2100, Loss: 4.8742780685424805
Iteration 2200, Loss: 4.87418794631958
Iteration 2300, Loss: 4.873963356018066
Iteration 2400, Loss: 4.876402378082275
Iteration 

In [6]:
import jax.numpy as jnp
from jax import random, grad, jit
from flax import linen as nn
import optax


batch_size = 16
block_size = 32
learning_rate = 1e-3
max_iters = 1000
n_embd = 64
vocab_size = 256  # Assuming ASCII

rng_key = random.PRNGKey(0)

class Transformer(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = x[..., None]  # Adding an embedding dimension
        x = nn.Dense(n_embd)(x)
        x = nn.LayerNorm()(x)
        x = nn.SelfAttention(num_heads=2)(x)
        x = x.reshape((x.shape[0], x.shape[1], -1))  # Flattening the last dimensions
        x = nn.Dense(vocab_size)(x)
        return x

@jit
def softmax_cross_entropy(logits, targets):
    logits_reshaped = logits.reshape((-1, vocab_size))
    targets_reshaped = targets.reshape((-1,))
    logprobs = jax.nn.log_softmax(logits_reshaped)

    targets_one_hot = jax.nn.one_hot(targets_reshaped, vocab_size)

    loss_values = -jnp.sum(targets_one_hot * logprobs, axis=-1)

    return loss_values.reshape((batch_size, block_size))

@jit
def compute_loss(params, x, y):
    logits = model.apply(params, x)
    loss_values = softmax_cross_entropy(logits, y)
    mean_loss = jnp.mean(loss_values)
    return mean_loss

@jit
def update(params, x, y, opt_state):
    opt_update = optimizer.update
    loss, grads = jax.value_and_grad(compute_loss)(params, x, y)  # removed `model` from the arguments
    updates, new_opt_state = opt_update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss


data = jnp.array([i % vocab_size for i in range(10000)], dtype=jnp.int32)
def get_batch():
    idx = random.randint(rng_key, (batch_size,), 0, len(data) - block_size - 1)
    x = jnp.array([data[i:i+block_size] for i in idx])
    y = jnp.array([data[i+1:i+block_size+1] for i in idx])
    return x, y

model = Transformer()
params = model.init(rng_key, jnp.ones((batch_size, block_size)))
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)

for iter in range(max_iters):
    x, y = get_batch()
    params, opt_state, loss = update(params, x, y, opt_state)
    if iter % 100 == 0:
        print(f"Iteration {iter}, Loss: {loss}")

# Simple text generation
def generate_text(params, model, start_token=0, length=100):
    generated = [start_token]

    current_token = jnp.array([start_token] * block_size).reshape(1, block_size)

    for _ in range(length):
        logits = model.apply(params, current_token)  # Generate logits for the sequence
        next_token = jnp.argmax(logits[0, -1])
        generated.append(int(next_token))

        current_token = jnp.concatenate([current_token, next_token.reshape(1, 1)], axis=1)[:, -block_size:]

    return generated


print(generate_text(params, model, start_token=0, length=100))


Iteration 0, Loss: 6.212061882019043
Iteration 100, Loss: 4.883028030395508
Iteration 200, Loss: 4.875763893127441
Iteration 300, Loss: 4.875391960144043
Iteration 400, Loss: 4.875297546386719
Iteration 500, Loss: 4.875267505645752
Iteration 600, Loss: 4.875245571136475
Iteration 700, Loss: 4.875204086303711
Iteration 800, Loss: 4.875190258026123
Iteration 900, Loss: 4.875199317932129
[0, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72]
