In [1]:
import json
import random
import math

def generate_manim_dataset():
    """Generate a comprehensive dataset of scene descriptions and corresponding Manim code"""

    dataset = []

    # 1. Basic Shape Animations (100 samples)
    shapes = ['Circle', 'Square', 'Triangle', 'Rectangle', 'Polygon']
    colors = ['RED', 'BLUE', 'GREEN', 'YELLOW', 'PURPLE', 'ORANGE', 'PINK', 'CYAN']
    animations = ['FadeIn', 'DrawBorderThenFill', 'Create', 'ShowCreation', 'Write']

    for i in range(10):
        shape = random.choice(shapes)
        color = random.choice(colors)
        animation = random.choice(animations)

        # Initialize description and code variables
        description = ""
        code = ""

        if shape == 'Circle':
            radius = round(random.uniform(0.5, 2.0), 1)
            description = f"A {color.lower()} circle with radius {radius} appears on screen"
            code = f"""from manim import *

class Scene{i+1}(Scene):
    def construct(self):
        circle = Circle(radius={radius}).set_color({color})
        self.play({animation}(circle))
        self.wait(1)"""

        elif shape == 'Square':
            side = round(random.uniform(1.0, 3.0), 1)
            description = f"A {color.lower()} square with side length {side} materializes"
            code = f"""from manim import *

class Scene{i+1}(Scene):
    def construct(self):
        square = Square(side_length={side}).set_color({color})
        self.play({animation}(square))
        self.wait(1)"""

        elif shape == 'Rectangle':
            width = round(random.uniform(2.0, 4.0), 1)
            height = round(random.uniform(1.0, 2.5), 1)
            description = f"A {color.lower()} rectangle {width}x{height} units draws itself"
            code = f"""from manim import *

class Scene{i+1}(Scene):
    def construct(self):
        rect = Rectangle(width={width}, height={height}).set_color({color})
        self.play({animation}(rect))
        self.wait(1)"""

        elif shape == 'Triangle':
            description = f"A {color.lower()} triangle appears with animation"
            code = f"""from manim import *

class Scene{i+1}(Scene):
    def construct(self):
        triangle = Triangle().set_color({color})
        self.play({animation}(triangle))
        self.wait(1)"""

        elif shape == 'Polygon':
            sides = random.choice([5, 6, 8])
            description = f"A {color.lower()} {sides}-sided polygon appears on screen"
            code = f"""from manim import *

class Scene{i+1}(Scene):
    def construct(self):
        polygon = RegularPolygon(n={sides}).set_color({color})
        self.play({animation}(polygon))
        self.wait(1)"""

        dataset.append({
            "id": i+1,
            "description": description,
            "duration": "1-2 seconds",
            "category": "basic_shapes",
            "manim_code": code
        })

    # 2. Text Animations (80 samples)
    text_samples = [
        "Hello World", "Mathematics", "Physics", "Chemistry", "Biology",
        "Welcome", "Python", "Manim", "Animation", "Science",
        "Learning", "Education", "Teaching", "Students", "Knowledge"
    ]

    for i in range(8):
        text = random.choice(text_samples)
        color = random.choice(colors)
        font_size = random.choice([24, 36, 48, 60])

        description = f"The text '{text}' appears in {color.lower()} color"
        code = f"""from manim import *

class Scene{i+101}(Scene):
    def construct(self):
        text = Text("{text}", font_size={font_size}).set_color({color})
        self.play(Write(text))
        self.wait(1)"""

        dataset.append({
            "id": i+101,
            "description": description,
            "duration": "1-2 seconds",
            "category": "text_animation",
            "manim_code": code
        })

    # 3. Mathematical Expressions (70 samples)
    expressions = [
        "x^2 + y^2 = r^2", "E = mc^2", "a^2 + b^2 = c^2",
        "\\frac{d}{dx}x^2 = 2x", "\\int x dx = \\frac{x^2}{2}",
        "\\sin^2(x) + \\cos^2(x) = 1", "F = ma", "PV = nRT"
    ]

    for i in range(7):
        expr = random.choice(expressions)
        color = random.choice(colors)

        description = f"Mathematical equation '{expr}' is written on screen"
        code = f"""from manim import *

class Scene{i+181}(Scene):
    def construct(self):
        equation = MathTex(r"{expr}").set_color({color})
        self.play(Write(equation))
        self.wait(1)"""

        dataset.append({
            "id": i+181,
            "description": description,
            "duration": "1-2 seconds",
            "category": "math_expressions",
            "manim_code": code
        })

    # 4. Movement Animations (60 samples)
    directions = ['LEFT', 'RIGHT', 'UP', 'DOWN']

    for i in range(600):
        shape = random.choice(['Circle', 'Square'])
        color = random.choice(colors)
        direction = random.choice(directions)
        distance = round(random.uniform(1.0, 3.0), 1)

        description = f"A {color.lower()} {shape.lower()} moves {direction.lower()} by {distance} units"

        if shape == 'Circle':
            code = f"""from manim import *

class Scene{i+251}(Scene):
    def construct(self):
        circle = Circle().set_color({color})
        self.play(Create(circle))
        self.play(circle.animate.shift({direction} * {distance}))
        self.wait(0.5)"""
        else:
            code = f"""from manim import *

class Scene{i+251}(Scene):
    def construct(self):
        square = Square().set_color({color})
        self.play(Create(square))
        self.play(square.animate.shift({direction} * {distance}))
        self.wait(0.5)"""

        dataset.append({
            "id": i+251,
            "description": description,
            "duration": "1-2 seconds",
            "category": "movement",
            "manim_code": code
        })

    # 5. Rotation Animations (50 samples)
    for i in range(5):
        shape = random.choice(['Square', 'Rectangle', 'Triangle'])
        color = random.choice(colors)
        angle = random.choice([45, 90, 180, 270, 360])

        description = f"A {color.lower()} {shape.lower()} rotates {angle} degrees"

        if shape == 'Square':
            code = f"""from manim import *

class Scene{i+311}(Scene):
    def construct(self):
        square = Square().set_color({color})
        self.play(Create(square))
        self.play(Rotate(square, {angle}*DEGREES))
        self.wait(0.5)"""
        elif shape == 'Rectangle':
            code = f"""from manim import *

class Scene{i+311}(Scene):
    def construct(self):
        rect = Rectangle().set_color({color})
        self.play(Create(rect))
        self.play(Rotate(rect, {angle}*DEGREES))
        self.wait(0.5)"""
        else:  # Triangle
            code = f"""from manim import *

class Scene{i+311}(Scene):
    def construct(self):
        triangle = Triangle().set_color({color})
        self.play(Create(triangle))
        self.play(Rotate(triangle, {angle}*DEGREES))
        self.wait(0.5)"""

        dataset.append({
            "id": i+311,
            "description": description,
            "duration": "1-2 seconds",
            "category": "rotation",
            "manim_code": code
        })

    # 6. Scaling Animations (40 samples)
    for i in range(4):
        shape = random.choice(['Circle', 'Square'])
        color = random.choice(colors)
        scale_factor = round(random.uniform(0.5, 2.5), 1)

        description = f"A {color.lower()} {shape.lower()} scales by factor {scale_factor}"

        if shape == 'Circle':
            code = f"""from manim import *

class Scene{i+361}(Scene):
    def construct(self):
        circle = Circle().set_color({color})
        self.play(Create(circle))
        self.play(circle.animate.scale({scale_factor}))
        self.wait(0.5)"""
        else:
            code = f"""from manim import *

class Scene{i+361}(Scene):
    def construct(self):
        square = Square().set_color({color})
        self.play(Create(square))
        self.play(square.animate.scale({scale_factor}))
        self.wait(0.5)"""

        dataset.append({
            "id": i+361,
            "description": description,
            "duration": "1-2 seconds",
            "category": "scaling",
            "manim_code": code
        })

    # 7. Graph and Function Plotting (50 samples)
    functions = [
        ("x^2", "x**2", "parabola"),
        ("sin(x)", "np.sin(x)", "sine wave"),
        ("cos(x)", "np.cos(x)", "cosine wave"),
        ("x^3", "x**3", "cubic function"),
        ("e^x", "np.exp(x)", "exponential function")
    ]

    for i in range(5):
        func_tex, func_code, func_name = random.choice(functions)
        color = random.choice(colors)

        description = f"A {color.lower()} graph of {func_name} {func_tex} appears"
        code = f"""from manim import *
import numpy as np

class Scene{i+401}(Scene):
    def construct(self):
        axes = Axes(x_range=[-3, 3], y_range=[-3, 3])
        graph = axes.plot(lambda x: {func_code}, color={color})
        self.play(Create(axes))
        self.play(Create(graph))
        self.wait(1)"""

        dataset.append({
            "id": i+401,
            "description": description,
            "duration": "2 seconds",
            "category": "graphs",
            "manim_code": code
        })

    # 8. Transformation Animations (50 samples)
    transformations = [
        ("Circle", "Square", "circle transforms into square"),
        ("Square", "Triangle", "square morphs into triangle"),
        ("Triangle", "Circle", "triangle becomes circle"),
        ("Rectangle", "Circle", "rectangle transforms to circle")
    ]

    for i in range(5):
        shape1, shape2, description_text = random.choice(transformations)
        color = random.choice(colors)

        description = f"A {color.lower()} {description_text}"
        code = f"""from manim import *

class Scene{i+451}(Scene):
    def construct(self):
        shape1 = {shape1}().set_color({color})
        shape2 = {shape2}().set_color({color})
        self.play(Create(shape1))
        self.play(Transform(shape1, shape2))
        self.wait(1)"""

        dataset.append({
            "id": i+451,
            "description": description,
            "duration": "2 seconds",
            "category": "transformations",
            "manim_code": code
        })

    return dataset

# Generate the dataset
dataset = generate_manim_dataset()

# Save as JSON file
with open('manim_scene_dataset.json', 'w', encoding='utf-8') as f:
    json.dump(dataset, f, indent=2, ensure_ascii=False)

print(f"Dataset generated with {len(dataset)} samples")
print("Categories distribution:")
categories = {}
for item in dataset:
    cat = item['category']
    categories[cat] = categories.get(cat, 0) + 1

for cat, count in categories.items():
    print(f"  {cat}: {count} samples")

print("\nDataset saved as 'manim_scene_dataset.json'")
print("Sample structure:")
print(json.dumps(dataset[0], indent=2))

Dataset generated with 644 samples
Categories distribution:
  basic_shapes: 10 samples
  text_animation: 8 samples
  math_expressions: 7 samples
  movement: 600 samples
  rotation: 5 samples
  scaling: 4 samples
  graphs: 5 samples
  transformations: 5 samples

Dataset saved as 'manim_scene_dataset.json'
Sample structure:
{
  "id": 1,
  "description": "A blue rectangle 2.4x1.3 units draws itself",
  "duration": "1-2 seconds",
  "category": "basic_shapes",
  "manim_code": "from manim import *\n\nclass Scene1(Scene):\n    def construct(self):\n        rect = Rectangle(width=2.4, height=1.3).set_color(BLUE)\n        self.play(ShowCreation(rect))\n        self.wait(1)"
}


In [2]:
import json
import random
import math
import torch
import tiktoken
from torch.utils.data import Dataset, DataLoader


class DataCreation():
    def __init__(self):
        self.DataList = json.load(open(r"manim_scene_dataset.json"))

        pass
    def converting_to_continuous_Text(self):
        data_list = self.DataList
        formatted_list = []
        for unformatted_data_item in data_list:

            formatted_list.append( "for visualization of the scene with description: " + unformatted_data_item['description']\
                + "and duration: "+ str(unformatted_data_item['duration']) + ", python and manim based code is the follwoing:\n" \
                    + unformatted_data_item['manim_code'])

        return formatted_list
    def verify_formatting(self):
        data_list = self.DataList
        unformatted_data_item = data_list[0]
        formatted_item = "for visualization of the scene with description: " + unformatted_data_item['description']\
                + "and duration: "+ str(unformatted_data_item['duration']) + ", python and manim based code is the follwoing:\n" \
                    + unformatted_data_item['manim_code']

        return formatted_item
    def save_formatted_data(self, formatted_data_list):
        json.dump(formatted_data_list, open(r"formatted_manim_scene_dataset.json", "w"))

class LMDataset(Dataset):
    def __init__(self, txt_list,tokenizer, max_length, stride):
        self.input_ids = []
        self.target_ids = []
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.stride = stride

        def tokenize_one_sample(txt):
            token_ids = self.tokenizer.encode(txt, allowed_special={"<|endoftext|>"})
            for i in range(0, len(token_ids) - self.max_length, self.stride):
                input_chunk = token_ids[i:i + self.max_length]
                target_chunk = token_ids[i + 1: i + self.max_length + 1]
                self.input_ids.append(torch.tensor(input_chunk))
                self.target_ids.append(torch.tensor(target_chunk))

        for txt in txt_list:
            tokenize_one_sample(txt)
    def __len__(self):
        return len(self.input_ids)
    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]

def LMDataloader(txt_list,batch_size, max_length, stride, shuffle = True, drop_list = True, num_workers = 0):
    tokenizer = tiktoken.get_encoding("gpt2")
    dataset = LMDataset(txt_list,tokenizer, max_length, stride)

    dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = shuffle, drop_last = drop_list, num_workers = num_workers)
    return dataloader

def finalDataLoader(batch_size, max_length, stride, train_ratio = 0.8, shuffle = True, drop_list = True, num_workers = 0):
    try:
        formatted_dataset = json.load(open(r"formatted_manim_scene_dataset.json"))
    except Exception as e:
        print(f"formatted_manim_scene_dataset.json not found with error:{e}")
        formatted_dataset = DataCreation().converting_to_continuous_Text()
    print("formatted sample size: ", len(formatted_dataset))
    DataCreation().save_formatted_data(formatted_dataset)


    print(formatted_dataset[0])
    print(len(formatted_dataset))
    split_idx = int(train_ratio * len(formatted_dataset))
    train_data = formatted_dataset[:split_idx]
    val_data = formatted_dataset[split_idx:]
    traindataloader = LMDataloader(
        txt_list=train_data,
        batch_size=batch_size,
        max_length=max_length,
        stride=stride
    )
    valdataloader = LMDataloader(
        txt_list=val_data,
        batch_size=batch_size,
        max_length=max_length,
        stride=stride
    )
    return traindataloader, valdataloader




if __name__ == "__main__":

    batch_size = 8
    max_length = 4
    stride = 1

    vocab_size = 50257
    output_dim = 256
    context_length = 1024

    token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)
    pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)

    dataloader,_ = finalDataLoader(batch_size, max_length, stride)
    print(type(dataloader))



    for batch in dataloader:
        x, y = batch

        token_embeddings = token_embedding_layer(x)
        pos_embeddings = pos_embedding_layer(torch.arange(max_length))

        input_embeddings = token_embeddings + pos_embeddings

        break


formatted_manim_scene_dataset.json not found with error:[Errno 2] No such file or directory: 'formatted_manim_scene_dataset.json'
formatted sample size:  644
for visualization of the scene with description: A blue rectangle 2.4x1.3 units draws itselfand duration: 1-2 seconds, python and manim based code is the follwoing:
from manim import *

class Scene1(Scene):
    def construct(self):
        rect = Rectangle(width=2.4, height=1.3).set_color(BLUE)
        self.play(ShowCreation(rect))
        self.wait(1)
644
<class 'torch.utils.data.dataloader.DataLoader'>


In [3]:
import torch
import torch.nn as nn
# from untokenizeddataformation import finalDataLoader

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out,d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)
        query = self.W_query(x)
        values = self.W_value(x)

        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = query.view(b, num_tokens, self.num_heads, self.head_dim)

        keys = keys.transpose(1,2)
        queries = queries.transpose(1,2)
        values = values.transpose(1,2)


        attn_scores = queries @ keys.transpose(2,3)

        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1,2)

        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)
        return context_vec

# if __name__ == "__main__":
#     torch.manual_seed(123)
#     batch_size = 8
#     max_length = 4
#     stride = 1

#     vocab_size = 50257
#     output_dim = 256
#     context_length = 1024

#     token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)
#     pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)

#     dataloader = finalDataLoader(batch_size, max_length, stride)

#     for batch in dataloader:
#         x, y = batch

#         token_embeddings = token_embedding_layer(x)
#         pos_embeddings = pos_embedding_layer(torch.arange(max_length))

#         input_embeddings = token_embeddings + pos_embeddings

#         break

#     context_length = max_length
#     d_in = output_dim
#     d_out = d_in

#     mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

#     batch = input_embeddings
#     context_vecs = mha(batch)

#     print("context_vecs.shape:", context_vecs.shape)

In [4]:
import tiktoken
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
# from multiheadattaention import *



class GELU(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self,x):
        return 0.5*x*(1+torch.tanh(
            torch.sqrt(torch.tensor(2.0/torch.pi)) * (x+0.044715 * torch.pow(x,3))
        ))


class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            GELU(),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
        )

    def forward(self, x):
        return self.layers(x)

class LayerNorm(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(embed_dim))
        self.shift = nn.Parameter(torch.zeros(embed_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim = True, unbiased = False)
        norm_x = (x-mean)/ torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift



class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])


    def forward(self, x):
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)   # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back
        return x




class ManimGPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])


        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])


        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(
            cfg["emb_dim"], cfg["vocab_size"], bias=False
        )

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits

# if __name__ == '__main__':
#     import tiktoken

#     Manim_GPT_CONFIG= {
#         "vocab_size": 50257,    # Vocabulary size
#         "context_length": 1024, # Context length
#         "emb_dim": 256,         # Embedding dimension
#         "n_heads": 4,          # Number of attention heads
#         "n_layers": 2,         # Number of layers
#         "drop_rate": 0.20,       # Dropout rate
#         "qkv_bias": False       # Query-Key-Value bias
#     }


#     tokenizer = tiktoken.get_encoding("gpt2")

#     batch = []

#     txt1 = "Every effort moves you"
#     txt2 = "Every day holds a"

#     batch.append(torch.tensor(tokenizer.encode(txt1)))
#     batch.append(torch.tensor(tokenizer.encode(txt2)))
#     batch = torch.stack(batch, dim=0)
#     print(batch)

#     model = ManimGPTModel(Manim_GPT_CONFIG)
#     logits = model(batch)
#     print(logits.shape)
#     total_params = sum(p.numel() for p in model.parameters())
#     print(f"Total number of parameters: {total_params:,}")

In [7]:
import torch
import torch.nn as nn
# from untokenizeddataformation import finalDataLoader
# from multiheadattaention import *
# from manimgpt import *
import tiktoken


Manim_GPT_CONFIG= {
    "vocab_size": 50257,    # Vocabulary size
    "context_length": 1024, # Context length
    "emb_dim": 768,         # Embedding dimension
    "n_heads": 4,          # Number of attention heads
    "n_layers": 2,         # Number of layers
    "drop_rate": 0.1,       # Dropout rate
    "qkv_bias": False       # Query-Key-Value bias
}



def generate_text_simple(model, idx, max_new_tokens, context_size):
    # idx is (B, T) array of indices in the current context
    for _ in range(max_new_tokens):
        print(f"step number:{_}")

        # Crop current context if it exceeds the supported context size
        # E.g., if LLM supports only 5 tokens, and the context size is 10
        # then only the last 5 tokens are used as context
        idx_cond = idx[:, -context_size:]

        # Get the predictions
        with torch.no_grad():
            logits = model(idx_cond)

        # Focus only on the last time step
        # (batch, n_token, vocab_size) becomes (batch, vocab_size)
        logits = logits[:, -1, :]

        # Get the idx of the vocab entry with the highest logits value
        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)

        # Append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)

    return idx


def text_to_token_ids(text,tokenizer):
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)
    return encoded_tensor
def token_ids_to_text(token_ids,tokenizer):
    flat = token_ids.squeeze(0)
    return tokenizer.decode(flat.tolist())


def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss


def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Reduce the number of batches to match the total number of batches in the data loader
        # if num_batches exceeds the number of batches in the data loader
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches

def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches = eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches = eval_iter)
    model.train()
    return train_loss, val_loss


def train_model(model, train_loader, val_loader, optimizer, device, num_epochs,
                eval_freq, eval_iter):
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen, global_step = 0, -1

    for epoch in range(num_epochs):
        model.train()

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad(optimizer)
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()
            optimizer.step()
            tokens_seen += input_batch.numel()
            global_step += 1

            if global_step % eval_freq == 0:
                train_loss, eval_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter
                )
                train_losses.append(train_loss)
                val_losses.append(eval_loss)
                track_tokens_seen.append(tokens_seen)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {eval_loss:.3f}")
    return train_losses, val_losses, track_tokens_seen

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ManimGPTModel(Manim_GPT_CONFIG).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.004, weight_decay=0.1)
    batch_size = 20
    max_length = 100
    stride = 5
    train_loader, val_loader = finalDataLoader(batch_size, max_length, stride)

    num_epochs = 5
    print(len(train_loader))

    train_losses, val_losses, tokens_seen = train_model(
    model, train_loader, val_loader, optimizer, device,
    num_epochs=num_epochs, eval_freq=5, eval_iter=5
    )


    torch.save(model.state_dict(), "manimgpt.pt")



formatted sample size:  644
for visualization of the scene with description: A blue rectangle 2.4x1.3 units draws itselfand duration: 1-2 seconds, python and manim based code is the follwoing:
from manim import *

class Scene1(Scene):
    def construct(self):
        rect = Rectangle(width=2.4, height=1.3).set_color(BLUE)
        self.play(ShowCreation(rect))
        self.wait(1)
644
200
Ep 1 (Step 000000): Train loss 5.695, Val loss 5.869
Ep 1 (Step 000005): Train loss 3.753, Val loss 4.049
Ep 1 (Step 000010): Train loss 3.223, Val loss 3.506
Ep 1 (Step 000015): Train loss 2.896, Val loss 3.236
Ep 1 (Step 000020): Train loss 2.549, Val loss 2.971
Ep 1 (Step 000025): Train loss 2.216, Val loss 2.694
Ep 1 (Step 000030): Train loss 1.897, Val loss 2.736
Ep 1 (Step 000035): Train loss 1.590, Val loss 2.163
Ep 1 (Step 000040): Train loss 1.378, Val loss 1.981
Ep 1 (Step 000045): Train loss 1.262, Val loss 1.819
Ep 1 (Step 000050): Train loss 1.025, Val loss 1.524
Ep 1 (Step 000055): Train 

In [None]:
print("checking inference")
model.to("cpu")
model.eval()

tokenizer = tiktoken.get_encoding("gpt2")

token_ids = generate_text_simple(
    model=model,
    idx=text_to_token_ids("for visualization of the scene with description: A purple square morphs into triangleand duration: 2 seconds, python and manim based code is the follwoing:", tokenizer),
    max_new_tokens=1000,
    context_size=Manim_GPT_CONFIG["context_length"]
)

print("Output text:\n", token_ids_to_text(token_ids, tokenizer))



checking inference
