In [None]:
# optional setup; use if the notebook is not running inside the rtfm conda environment
!git clone https://github.com/mlfoundations/rtfm.git
%cd rtfm

# Ensure pip is up to date
!pip install --upgrade pip

# Install Python 3.8 using pip
!pip install python==3.8

# Install pip dependencies from requirements.txt
!pip install -r requirements.txt

# Install additional dependencies
!pip install git+https://github.com/jpgard/llama-recipes.git
!pip install -e .
!pip install --no-deps git+https://github.com/mlfoundations/tableshift.git

In [3]:
!pip install tablib

Collecting tablib
  Downloading tablib-3.7.0-py3-none-any.whl.metadata (3.8 kB)
Downloading tablib-3.7.0-py3-none-any.whl (47 kB)
Installing collected packages: tablib
Successfully installed tablib-3.7.0


In [4]:
import sys

# Make the cloned repository importable. The membership check keeps the cell
# idempotent: re-running it does not stack duplicate entries on sys.path.
if '/content/rtfm' not in sys.path:
    sys.path.append('/content/rtfm')

In [None]:
!pip install transformers accelerate

In [14]:
import pandas as pd
import torch
from transformers import AutoTokenizer, LlamaForCausalLM, AutoConfig

# Define TrainConfig and TokenizerConfig manually
class TrainConfig:
    """Hand-written container for the model settings used in this notebook.

    Attributes:
        model_name: Model identifier, stored as given.
        context_length: Context length, 8192 unless overridden.
        serializer_cls: Optional serializer class, ``None`` unless overridden.
    """

    def __init__(
        self,
        model_name,
        context_length=8192,
        serializer_cls=None,
    ):
        # Plain attribute bag: values are stored as passed, without validation.
        self.model_name = model_name
        self.context_length = context_length
        self.serializer_cls = serializer_cls

class TokenizerConfig:
    """Hand-written container for the tokenizer options used in this notebook.

    Attributes:
        use_fast_tokenizer: Stored as given; ``True`` by default.
        add_serializer_tokens: Stored as given; ``False`` by default.
        serializer_tokens_embed_fn: Stored as given; ``None`` by default.
    """

    def __init__(
        self,
        use_fast_tokenizer=True,
        add_serializer_tokens=False,
        serializer_tokens_embed_fn=None,
    ):
        # Plain attribute bag: values are stored as passed, without validation.
        self.use_fast_tokenizer = use_fast_tokenizer
        self.add_serializer_tokens = add_serializer_tokens
        self.serializer_tokens_embed_fn = serializer_tokens_embed_fn

# Initialize configuration
train_config = TrainConfig(model_name="mlfoundations/tabula-8b", context_length=8192)
tokenizer_config = TokenizerConfig()

# Load the configuration
config = AutoConfig.from_pretrained(train_config.model_name)
config.torch_dtype = torch.bfloat16  # Match TabuLa train/eval setup

# Name of the preferred device ("cuda" when available, else "cpu"). It is not used for
# model placement in this cell: that is delegated to `device_map="auto"` below.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model with `accelerate` managing the device map.
# `torch_dtype` is passed explicitly because setting `config.torch_dtype` alone does not
# select the dtype the weights are loaded in; without it the checkpoint may be
# materialised in float32, roughly doubling memory use versus bfloat16.
model = LlamaForCausalLM.from_pretrained(
    train_config.model_name,
    device_map="auto",  # Automatically manages device placement
    offload_folder="./offload",  # Directory to offload if needed
    config=config,
    torch_dtype=torch.bfloat16,
)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(train_config.model_name)

# Dummy serializer implementation
class DummySerializer:
    """Placeholder serializer that only exposes a list of special-token strings."""

    def __init__(self):
        # A fresh list is built per instance, so instances never share the same list.
        start_token, end_token = "<|start|>", "<|end|>"
        self.special_tokens = [start_token, end_token]


serializer = DummySerializer()

# Prepare tokenizer
def prepare_tokenizer(
    model,
    tokenizer,
    pretrained_model_name_or_path,
    model_max_length,
    use_fast_tokenizer,
    serializer_tokens_embed_fn=None,
    serializer_tokens=None,
):
    """Optionally register extra special tokens, then set the tokenizer's max length.

    Args:
        model: Object exposing ``resize_token_embeddings(int)``; only called when
            ``serializer_tokens`` is truthy.
        tokenizer: Object exposing ``add_special_tokens(dict)``, ``__len__`` and a
            writable ``model_max_length`` attribute.
        pretrained_model_name_or_path: Accepted but not used by this implementation.
        model_max_length: Value assigned to ``tokenizer.model_max_length``.
        use_fast_tokenizer: Accepted but not used by this implementation.
        serializer_tokens_embed_fn: Accepted but not used by this implementation.
        serializer_tokens: Extra special tokens to register; skipped when falsy.

    Returns:
        The ``(tokenizer, model)`` pair, both modified in place.
    """
    wants_extra_tokens = bool(serializer_tokens)
    if wants_extra_tokens:
        special_tokens_update = {"additional_special_tokens": serializer_tokens}
        tokenizer.add_special_tokens(special_tokens_update)
        # The embedding matrix has to match the tokenizer's new vocabulary size.
        vocab_size = len(tokenizer)
        model.resize_token_embeddings(vocab_size)

    tokenizer.model_max_length = model_max_length
    return (tokenizer, model)

# Forward the serializer's special tokens only when the tokenizer config opts in;
# otherwise no extra tokens are passed.
if tokenizer_config.add_serializer_tokens:
    serializer_tokens = serializer.special_tokens
else:
    serializer_tokens = None

tokenizer, model = prepare_tokenizer(
    model,
    tokenizer=tokenizer,
    pretrained_model_name_or_path=train_config.model_name,
    model_max_length=train_config.context_length,
    use_fast_tokenizer=tokenizer_config.use_fast_tokenizer,
    serializer_tokens_embed_fn=tokenizer_config.serializer_tokens_embed_fn,
    serializer_tokens=serializer_tokens,
)

# Reuse the end-of-sequence token for padding and mirror its id in the model config.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [17]:
import pandas as pd

# Labeled rows: one tuple per title, in the column order listed under `columns`.
# The last field ("recommendation") is the label column.
labeled_anime = pd.DataFrame(
    [
        ("Naruto", "Action", 220, 7.9, 95, 2002, "Pierrot", "Manga", 220, "Yes"),
        ("Attack on Titan", "Action", 87, 9.1, 98, 2013, "Wit Studio", "Manga", 87, "Yes"),
        ("Death Note", "Mystery", 37, 8.6, 96, 2006, "Madhouse", "Manga", 37, "Yes"),
        ("One Piece", "Adventure", 1071, 8.7, 97, 1999, "Toei Animation", "Manga", 100, "Yes"),
        ("My Hero Academia", "Action", 138, 8.0, 94, 2016, "Bones", "Manga", 120, "Yes"),
        ("Tokyo Ghoul", "Horror", 48, 7.6, 88, 2014, "Pierrot", "Manga", 48, "No"),
        ("Sword Art Online", "Fantasy", 96, 7.5, 92, 2012, "A-1 Pictures", "Light Novel", 96, "No"),
        ("Demon Slayer", "Action", 44, 8.7, 99, 2019, "Ufotable", "Manga", 44, "Yes"),
        ("Black Clover", "Fantasy", 170, 8.0, 90, 2017, "Pierrot", "Manga", 100, "No"),
    ],
    columns=[
        "title",
        "genre",
        "episodes",
        "rating",
        "popularity",
        "year",
        "studio",
        "source_material",
        "completed_episodes",
        "recommendation",
    ],
)

# Single unlabeled row: same feature columns as the labeled frame, but without
# the "recommendation" column.
target_anime = pd.DataFrame(
    [("Jujutsu Kaisen", "Action", 24, 8.9, 98, 2020, "MAPPA", "Manga", 24)],
    columns=[
        "title",
        "genre",
        "episodes",
        "rating",
        "popularity",
        "year",
        "studio",
        "source_material",
        "completed_episodes",
    ],
)

# Call `model.predict` for the single target row, naming "recommendation" as the target
# column, restricting the answer to "Yes"/"No", and passing the labeled frame as
# `labeled_examples`; then print the target frame together with the returned value.
# NOTE(review): in the cells visible in this notebook, `model` is the object returned by
# `LlamaForCausalLM.from_pretrained(...)`, and no visible cell defines a `predict` that
# takes these keyword arguments. The execution counts also jump from 14 to 17, so `model`
# was presumably rebound to an inference wrapper in a cell that is no longer present.
# TODO confirm and restore that cell so the notebook survives Restart & Run All.
output = model.predict(
    target_example=target_anime,
    target_colname="recommendation",
    target_choices=["Yes", "No"],
    labeled_examples=labeled_anime,
)
print(f"Prediction for sample \n {target_anime} \n is: {output}")


Prediction for sample
            title   genre  episodes  rating  popularity  year studio  \
0  Jujutsu Kaisen  Action        24     8.9          98  2020  MAPPA   

  source_material  completed_episodes  
0           Manga                  24  
is: Yes
