In [None]:
# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install transformers
!pip install diffusers
!pip install wandb
!pip install torch

## Load fine-tuned model

In [None]:
import wandb
import torch
import logging
from diffusers import StableDiffusionPipeline
from transformers import (AutoTokenizer, AutoModelForCausalLM,
                          DataCollatorForLanguageModeling, pipeline)

In [None]:
import logging

# Create a custom logger.
# The logger's own level must allow INFO: left unset, it inherits the root
# logger's level (WARNING by default) and INFO records are discarded before
# any handler sees them.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Create handlers
c_handler = logging.StreamHandler()
c_handler.setLevel(logging.INFO)

# Attach the handler so the logger actually emits output. Drop plain stream
# handlers left behind by earlier runs of this cell first, otherwise every
# re-run would add another handler and duplicate each message.
for _stale in [h for h in logger.handlers if type(h) is logging.StreamHandler]:
    logger.removeHandler(_stale)
logger.addHandler(c_handler)

In [None]:
# Notebook-wide configuration.

# Hugging Face model id used to load the tokenizer.
MODEL_ID_TEXT_GEN = 'gpt2'

# Not referenced by any visible cell; kept as an empty placeholder.
MODEL_ID_SUMMARIZE = ''

# Reference of the W&B model artifact to download
# (looks like "<entity>/<project>/<artifact>:<version>").
PROJECT_ID = "jbarata1998/song-generator/model-baseline_gpt2_finetune:v0"

In [None]:
# Download the fine-tuned model artifact from Weights & Biases (wandb).
# A wandb run is opened, the artifact named by PROJECT_ID is requested through
# that run as a 'model' artifact, and then downloaded.
# `artifact_dir` holds whatever download() returns; later cells use it both as
# the directory the tokenizer is saved into and as the text-generation model path.
# NOTE(review): no visible cell closes `run` — confirm whether run.finish() is wanted.
run = wandb.init()
artifact = run.use_artifact(PROJECT_ID, type='model')
artifact_dir = artifact.download()

In [None]:
# Initialise the tokenizer.
# It is loaded from MODEL_ID_TEXT_GEN ("gpt2"), not from the downloaded artifact.
# NOTE(review): presumably the artifact ships no tokenizer files — confirm.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID_TEXT_GEN)

In [None]:
# Save the tokenizer into the artifact directory so that model and tokenizer
# live side by side: CoverGenerator later passes this directory as `model=` to
# the text-generation pipeline.
tokenizer.save_pretrained(artifact_dir)

## Generate album cover

In [None]:
class CoverGenerator:
    """Turn a lyric prompt into an album cover in three stages.

    1. ``gen_text`` continues a prompt into a song with a text-generation model.
    2. ``summarize_text`` condenses that song into a short summary.
    3. ``gen_cover`` feeds the summary to a Stable Diffusion pipeline.

    The stages communicate through ``self.song`` and ``self.song_summary`` and
    therefore have to run in that order. Loaded pipelines are cached on the
    instance so repeated calls do not reload the model weights.
    """

    def __init__(self, gen_model: str, summarize_model: str, diffuse_model: str):
        """Store the model identifiers; nothing is loaded until first use.

        Args:
            gen_model: Name or local path of the text-generation model.
            summarize_model: Name or local path of the summarization model.
            diffuse_model: Name or local path of the Stable Diffusion checkpoint.
        """
        self.gen_model = gen_model
        self.summarize_model = summarize_model
        self.diffuse_model = diffuse_model
        # Stage outputs; None means "this stage has not produced a result yet".
        self.song = None
        self.song_summary = None
        # Loaded pipelines, keyed by (task, model id[, device]).
        self._pipelines = {}

    def _text_pipeline(self, task: str, model_id: str):
        """Return the transformers pipeline for (task, model_id), loading it once."""
        key = (task, model_id)
        if key not in self._pipelines:
            self._pipelines[key] = pipeline(task, model=model_id)
        return self._pipelines[key]

    def _diffusion_pipeline(self):
        """Return the Stable Diffusion pipeline, loading it once per device."""
        device = "cuda" if torch.cuda.is_available() else "cpu"
        key = ("text-to-image", self.diffuse_model, device)
        if key not in self._pipelines:
            # Half precision is meant for the GPU; fall back to float32 on CPU.
            dtype = torch.float16 if device == "cuda" else torch.float32
            pipe = StableDiffusionPipeline.from_pretrained(self.diffuse_model, torch_dtype=dtype)
            self._pipelines[key] = pipe.to(device)
        return self._pipelines[key]

    def gen_text(self, prompt: str, **kwargs):
        """Generate a song that continues ``prompt`` and store it in ``self.song``.

        Args:
            prompt: Opening text for the model to continue.
            **kwargs: Only ``top_k`` (default 5) and ``max_new_tokens``
                (default 400) are read; any other key is ignored.
        """
        generator = self._text_pipeline("text-generation", self.gen_model)
        result = generator(prompt, top_k=kwargs.get("top_k", 5), max_new_tokens=kwargs.get("max_new_tokens", 400))
        self.song = result[0]["generated_text"]
        # A new song invalidates any summary made from the previous one.
        self.song_summary = None
        print(f"PROMPT: {prompt} \n\n SONG: \n\n {self.song}")

    def summarize_text(self, **kwargs):
        """Summarize ``self.song`` into ``self.song_summary`` (best effort).

        Failures are reported on stdout instead of raised; in that case
        ``self.song_summary`` is left as ``None``.

        Args:
            **kwargs: Only ``min_length`` (default 5) and ``max_length``
                (default 75) are read; any other key is ignored.
        """
        # Reset first so a failure below can never leave a stale summary behind.
        self.song_summary = None
        if not self.song:
            print("No song to summarize: call gen_text() first")
            return
        try:
            summarizer = self._text_pipeline("summarization", self.summarize_model)
            summary = summarizer(self.song, min_length=kwargs.get("min_length", 5), max_length=kwargs.get("max_length", 75))
            self.song_summary = summary[0]["summary_text"]
            print(f" SUMMARY: \n\n {self.song_summary}")
        except Exception as e:
            print(f"Exception {e} occurred")

    def gen_cover(self, **kwargs):
        """Generate the cover image from ``self.song_summary``.

        Args:
            **kwargs: Forwarded unchanged to the Stable Diffusion pipeline call
                (for example ``height``, ``width``, ``num_inference_steps``).

        Returns:
            The first image produced by the pipeline.

        Raises:
            RuntimeError: If no summary is available, i.e. ``summarize_text``
                has not run successfully for the current song.
        """
        if not self.song_summary:
            raise RuntimeError(
                "No song summary available: run gen_text() and summarize_text() "
                "successfully before gen_cover()"
            )
        pipe = self._diffusion_pipeline()
        return pipe(self.song_summary, **kwargs).images[0]

In [None]:
# Opening lines that the text-generation model is asked to continue.
test_prompt = (
    "Well, good for you, I guess you moved on really easily\n"
    "You found a new girl and it only took a couple weeks\n"
    "Remember when you said that you wanted to give me the world?"
)

# Wire up the three stages: the downloaded model directory for generation,
# a BART checkpoint for summarization, Stable Diffusion v1.5 for the image.
cover_generator = CoverGenerator(
    gen_model=artifact_dir,
    summarize_model="facebook/bart-large-cnn",
    diffuse_model="runwayml/stable-diffusion-v1-5",
)

In [None]:
# Continue the prompt into a song; the text is printed and kept on
# cover_generator.song for the next stage.
cover_generator.gen_text(
    prompt=test_prompt,
    top_k=5,
    max_new_tokens=400,
)

In [None]:
# Summarize the generated song. summarize_text only reads min_length and
# max_length; the height/width arguments previously passed here were image
# options that this method silently ignored, so they are dropped.
cover_generator.summarize_text(min_length=5, max_length=75)

In [None]:
# Generate the cover from the stored song summary; the bare `image` expression
# on the last line makes the notebook display the result.
image = cover_generator.gen_cover()
image