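# Inference

Download the fine-tuned T5 code-formatting checkpoint and use it to reformat code snippets through a small interactive widget.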

## Imports and downloads

In [1]:
!pip install gdown



In [2]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from google.colab import drive
import os
import requests
import zipfile
import gdown
import ipywidgets as widgets
from IPython.display import display, Markdown

## Checkpoint download

In [3]:
file_id = '1RIDAda9iwBkUeqh1j4eetvRXyqA7hhnt'
url = f"https://drive.google.com/uc?id={file_id}"
output = "trained_model.zip"
model_dir = "model"

# Download the zipped checkpoint. gdown returns the output path on success;
# some gdown versions return None (instead of raising) when the download fails,
# so check explicitly rather than reporting success unconditionally.
downloaded = gdown.download(url, output, quiet=False)
if downloaded is None:
    raise RuntimeError(f"Failed to download model archive from {url}")

# Create a directory for the model and extract the contents
os.makedirs(model_dir, exist_ok=True)
with zipfile.ZipFile(output, "r") as zip_ref:
    zip_ref.extractall(model_dir)

# Resolve the file relative to model_dir instead of a hard-coded /content path,
# so the cleanup also happens when the notebook runs outside Colab.
# NOTE(review): presumably removed because it conflicts with the tokenizer's
# vocabulary when loading — confirm.
added_tokens_path = os.path.join(model_dir, "added_tokens.json")
if os.path.exists(added_tokens_path):
    os.remove(added_tokens_path)

print("Model folder downloaded and extracted successfully!")

Downloading...
From (original): https://drive.google.com/uc?id=1RIDAda9iwBkUeqh1j4eetvRXyqA7hhnt
From (redirected): https://drive.google.com/uc?id=1RIDAda9iwBkUeqh1j4eetvRXyqA7hhnt&confirm=t&uuid=50ffbe7b-148d-46ae-b516-9178421b3a40
To: /content/trained_model.zip
100%|██████████| 822M/822M [00:04<00:00, 166MB/s]


Model folder downloaded and extracted successfully!


## Model

In [4]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


In [5]:
def load_model(path=None, target_device=None):
    """Load the fine-tuned T5 tokenizer and model for inference.

    Args:
        path: Directory holding the checkpoint. Defaults to the notebook-level
            ``model_dir`` (resolved at call time).
        target_device: ``torch.device`` to move the model to. Defaults to the
            notebook-level ``device`` (resolved at call time).

    Returns:
        A ``(tokenizer, model)`` tuple; the model is placed on the target
        device and switched to eval mode.
    """
    # Resolve defaults at call time so the notebook globals are still honoured
    # for the zero-argument call, while other callers can pass explicit values.
    if path is None:
        path = model_dir
    if target_device is None:
        target_device = device

    tokenizer = T5Tokenizer.from_pretrained(path)
    model = T5ForConditionalGeneration.from_pretrained(path)
    model.to(target_device)
    # eval() turns off training-only behaviour such as dropout for inference.
    model.eval()
    return tokenizer, model


tokenizer, model = load_model()

In [6]:
def format_code(unformatted_code, max_length=512, num_beams=5):
    """Reformat a code snippet with the fine-tuned T5 model.

    Uses the notebook-level ``tokenizer``, ``model`` and ``device``.

    Args:
        unformatted_code: Source text to reformat.
        max_length: Token limit used both to truncate the input (anything
            beyond it is silently dropped) and to bound the generated output.
        num_beams: Beam width for beam-search decoding.

    Returns:
        The decoded model output with leading/trailing whitespace stripped.
    """
    input_ids = tokenizer.encode(
        unformatted_code,
        return_tensors='pt',
        truncation=True,
        max_length=max_length,
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )

    # Disable the tokenizer's English-prose cleanup: it rewrites " !" -> "!",
    # " ," -> ",", " ." -> ".", " 's" -> "'s", etc., which corrupts code
    # (e.g. `a != b` -> `a!= b`, `x = 's'` -> `x ='s'`).
    formatted_code = tokenizer.decode(
        outputs[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    return formatted_code.strip()

In [8]:
title = widgets.HTML("<h2>Code Formatter</h2>")
description = widgets.HTML("<p>Enter unformatted code and receive formatted code as per learned conventions.</p>")
code_input = widgets.Textarea(
    value='',
    placeholder='Type your unformatted code here...',
    description='Unformatted Code:',
    layout=widgets.Layout(width='100%', height='200px')
)

format_button = widgets.Button(description="Format Code", button_style='success')
output_area = widgets.Output()


def on_format_button_click(b):
    """Format the Textarea contents and render the result in ``output_area``.

    The clicked button is disabled while the model runs. A slow (e.g. CPU)
    beam search would otherwise let repeated clicks queue duplicate
    generations, with no sign that work is in progress.
    """
    with output_area:
        output_area.clear_output()
        unformatted_code = code_input.value.strip()
        if unformatted_code == "":
            display(Markdown("**Warning:** Please enter some code."))
            return

        b.disabled = True
        display(Markdown("_Formatting..._"))
        try:
            formatted_code = format_code(unformatted_code)
        finally:
            # Always re-enable, even if generation raised.
            b.disabled = False

        output_area.clear_output()
        display(Markdown("**Formatted Code:**"))
        display(Markdown(f"```python\n{formatted_code}\n```"))


format_button.on_click(on_format_button_click)

display(title, description, code_input, format_button, output_area)

HTML(value='<h2>Code Formatter</h2>')

HTML(value='<p>Enter unformatted code and receive formatted code as per learned conventions.</p>')

Textarea(value='', description='Unformatted Code:', layout=Layout(height='200px', width='100%'), placeholder='…

Button(button_style='success', description='Format Code', style=ButtonStyle())

Output()