<a href="https://colab.research.google.com/github/Fastpacer/Abstract_ART_Transformer/blob/main/Machine_Learning_Internship_Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install required dependencies
!pip install torch torchvision opencv-python pillow numpy tqdm gradio

# Import necessary libraries
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.utils import save_image
import numpy as np
from tqdm import tqdm
import cv2
from PIL import Image
import os
import gradio as gr
import tempfile
from google.colab import files
from IPython.display import HTML, display
from pathlib import Path

# First Cell: Model Implementation
class StyleEncoder(nn.Module):
    """Convolutional encoder producing one 128-dimensional vector per image.

    The network applies three ReLU-activated convolutions (3 -> 64 -> 128 ->
    256 channels, the last two with stride 2), averages each channel down to
    a single value, and projects the resulting 256 features to 128.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in a fixed order so that seeded initialisation
        # and the state_dict layout stay stable.
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=64, kernel_size=7, stride=1, padding=3
        )
        self.conv2 = nn.Conv2d(
            in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1
        )
        self.conv3 = nn.Conv2d(
            in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1
        )
        self.relu = nn.ReLU()
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_features=256, out_features=128)

    def forward(self, x):
        """Encode a batch of 3-channel images into a ``(batch, 128)`` tensor."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = self.relu(conv(features))

        # Global average pooling leaves a 1x1 map per channel; drop the
        # spatial dimensions before the linear projection.
        pooled = self.adaptive_pool(features)
        flattened = pooled.view(pooled.size(0), -1)
        return self.fc(flattened)

class TransformationGenerator(nn.Module):
    """Generator that repaints an image conditioned on a 128-dim style vector.

    The style vector is tiled over every pixel, concatenated to the image
    along the channel axis (3 + 128 channels), and pushed through three
    stride-1 convolutions. The final ``tanh`` bounds the output to [-1, 1]
    and the spatial size of the input is preserved.
    """

    def __init__(self):
        super().__init__()
        # NOTE: ``upsample`` is constructed but not used by ``forward``.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
        self.conv1 = nn.Conv2d(
            in_channels=3 + 128, out_channels=64, kernel_size=7, stride=1, padding=3
        )
        self.conv2 = nn.Conv2d(
            in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1
        )
        self.conv3 = nn.Conv2d(
            in_channels=128, out_channels=3, kernel_size=7, stride=1, padding=3
        )
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, image, style):
        """Return a 3-channel frame for ``image`` conditioned on ``style``.

        Args:
            image: 4-D tensor ``(batch, 3, height, width)``.
            style: tensor holding 128 values per batch element.
        """
        n, _channels, rows, cols = image.shape

        # Broadcast the style code to a full-resolution feature map.
        style_map = style.view(n, -1, 1, 1)
        style_map = style_map.repeat(1, 1, rows, cols)

        hidden = torch.cat([image, style_map], dim=1)
        for conv in (self.conv1, self.conv2):
            hidden = self.relu(conv(hidden))
        return self.tanh(self.conv3(hidden))

class AbstractArtTransformer:
    """Turn one image into a sequence of progressively abstracted frames.

    Each frame is produced by colour-jittering the input, smoothing it with a
    bilateral filter, and passing it through ``TransformationGenerator``
    together with a style vector. No checkpoint is loaded here, so both
    networks run with whatever weights their constructors initialise.
    """

    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Create the encoder and generator on ``device`` and report the device."""
        self.device = device
        self.style_encoder = StyleEncoder().to(device)
        self.generator = TransformationGenerator().to(device)
        print(f"Using device: {device}")

    def apply_brushstroke_effect(self, image):
        """Smooth ``image`` with an edge-preserving bilateral filter.

        Args:
            image: ``(C, H, W)`` tensor, or an ``(H, W, C)`` NumPy array.

        Returns:
            ``(C, H, W)`` tensor on ``self.device``.
        """
        if isinstance(image, torch.Tensor):
            # detach(): .numpy() refuses tensors that still track gradients.
            image = image.detach().cpu().numpy().transpose(1, 2, 0)
        image = np.ascontiguousarray(image)

        sigma_color = 75
        if np.issubdtype(image.dtype, np.floating):
            # bilateralFilter accepts only uint8 or float32 input.
            image = image.astype(np.float32, copy=False)
            # sigmaColor is expressed in pixel-value units. Float images are
            # assumed to live in [0, 1] (what ToTensor yields), so the 0-255
            # scale value must be rescaled; otherwise the filter degenerates
            # into a plain blur that ignores edges.
            sigma_color = 75 / 255.0

        brushstroke = cv2.bilateralFilter(image, 9, sigma_color, 75)
        chw = np.ascontiguousarray(brushstroke.transpose(2, 0, 1))
        return torch.from_numpy(chw).to(self.device)

    def apply_color_distortion(self, image, intensity=0.5):
        """Randomly jitter brightness, contrast, saturation and hue.

        Args:
            image: image accepted by ``transforms.ColorJitter``.
            intensity: jitter strength, clamped to [0, 1]. ``0`` leaves the
                image untouched; the default ``0.5`` reproduces the original
                fixed settings (0.5 / 0.5 / 0.5 / hue 0.1).
        """
        strength = min(max(float(intensity), 0.0), 1.0)
        jitter = transforms.ColorJitter(
            brightness=strength,
            contrast=strength,
            saturation=strength,
            hue=0.2 * strength,  # stays within ColorJitter's 0.5 hue limit
        )
        return jitter(image)

    def _style_vector(self, style_prompt):
        """Return a reproducible ``(1, 128)`` style code for ``style_prompt``.

        The prompt is hashed into a seed, so the same style always maps to
        the same vector and different styles map to different ones.
        """
        import zlib  # local import: keeps this block self-contained

        # crc32 is stable across processes, unlike the built-in hash().
        seed = zlib.crc32(str(style_prompt).encode("utf-8"))
        rng = torch.Generator().manual_seed(seed)
        return torch.randn(1, 128, generator=rng).to(self.device)

    def create_transformation_sequence(self, image, style_prompt, num_frames=30,
                                       max_abstraction=1.0):
        """Generate ``num_frames`` frames with linearly increasing abstraction.

        Args:
            image: ``(3, H, W)`` float tensor.
            style_prompt: style name; selects the style vector.
            num_frames: number of frames to generate (coerced to ``int``).
            max_abstraction: colour-distortion strength reached on the final
                frame; the first frame starts at ``0``.

        Returns:
            List of ``(3, H, W)`` tensors with values in [0, 1].
        """
        num_frames = int(num_frames)
        style_vector = self._style_vector(style_prompt)
        # Divide by the last index so the final frame reaches max_abstraction.
        last_index = max(num_frames - 1, 1)

        frames = []
        for i in tqdm(range(num_frames), desc="Generating frames"):
            abstraction_level = max_abstraction * i / last_index
            distorted_image = self.apply_color_distortion(image, intensity=abstraction_level)
            brushstroke_image = self.apply_brushstroke_effect(distorted_image)

            with torch.no_grad():
                transformed_frame = self.generator(brushstroke_image.unsqueeze(0), style_vector)

            # The generator ends in tanh ([-1, 1]); map to [0, 1] so that
            # create_video's 0-255 scaling is valid.
            frames.append((transformed_frame.squeeze(0) + 1.0) / 2.0)

        return frames

    @staticmethod
    def _frame_to_uint8(frame):
        """Convert a ``(C, H, W)`` tensor in [0, 1] to an ``(H, W, C)`` uint8 array."""
        array = frame.detach().cpu().numpy().transpose(1, 2, 0)
        # Clip before casting: out-of-range values would wrap around in uint8.
        scaled = np.clip(array * 255, 0, 255).astype(np.uint8)
        return np.ascontiguousarray(scaled)

    def create_video(self, frames, output_path, fps=24):
        """Encode ``frames`` as an ``mp4v`` video at ``output_path``.

        Args:
            frames: iterable of ``(3, H, W)`` RGB tensors with values in [0, 1].
            output_path: destination file path.
            fps: frames per second of the written video.

        Raises:
            ValueError: if ``frames`` is empty.
            RuntimeError: if OpenCV cannot open ``output_path`` for writing.
        """
        frames = [self._frame_to_uint8(frame) for frame in frames]
        if not frames:
            raise ValueError("create_video requires at least one frame")
        height, width, _ = frames[0].shape

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not video_writer.isOpened():
            # An unopened writer drops every frame silently.
            video_writer.release()
            raise RuntimeError(f"Could not open video writer for: {output_path}")

        try:
            for frame in tqdm(frames, desc="Writing video"):
                video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        finally:
            video_writer.release()

# Second Cell: Gradio Interface Implementation
class ArtTransformerApp:
    """Glue between the Gradio UI callbacks and ``AbstractArtTransformer``."""

    def __init__(self):
        self.transformer = AbstractArtTransformer()
        self.available_styles = ["Cubism", "Surrealism", "Impressionism", "Abstract Expressionism",
                                "Pop Art", "Minimalism", "Art Nouveau", "Pointillism"]
        # Every input is resized to 256x256 and converted to a [0, 1] tensor.
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])

    def process_image(self, input_image, style, abstraction_level, frame_count):
        """Render the abstraction video for one image and return its path.

        Args:
            input_image: PIL image, NumPy image array, or path to an image file.
            style: style name, used to select the style vector.
            abstraction_level: distortion strength reached on the last frame.
            frame_count: number of frames to generate.

        Returns:
            Path of the temporary ``.mp4`` file that was written.

        Raises:
            gr.Error: if no image was supplied.
        """
        if input_image is None:
            raise gr.Error("Please upload an input image first.")

        if isinstance(input_image, (str, os.PathLike)):
            # Load eagerly inside the context manager so the file is closed.
            with Image.open(input_image) as opened:
                input_image = opened.convert("RGB")
        else:
            if isinstance(input_image, np.ndarray):
                input_image = Image.fromarray(input_image)
            # The generator expects exactly 3 channels; RGBA or grayscale
            # input would otherwise fail inside the first convolution.
            input_image = input_image.convert("RGB")

        image_tensor = self.transform(input_image).to(self.transformer.device)

        frames = self.transformer.create_transformation_sequence(
            image_tensor,
            style,
            num_frames=int(frame_count),  # sliders may deliver floats
            max_abstraction=float(abstraction_level),
        )

        # Reserve a unique path; the handle is closed before OpenCV writes.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file:
            output_path = tmp_file.name

        self.transformer.create_video(frames, output_path)

        return output_path

# Third Cell: Create and Launch Gradio Interface
def create_gradio_interface():
    """Assemble the Gradio Blocks UI and return it without launching.

    The left column holds the inputs (image, style dropdown, abstraction
    slider, frame-count slider) and the trigger button; the right column
    shows the resulting video. Clicking the button calls
    ``ArtTransformerApp.process_image`` with the four inputs in that order.
    """
    art_app = ArtTransformerApp()

    with gr.Blocks(title="Abstract Art Style Transformer") as demo:
        gr.Markdown("# Abstract Art Style Transformer")

        with gr.Row():
            # Left column: user-controlled inputs.
            with gr.Column():
                image_input = gr.Image(type="pil", label="Input Image")
                style_choice = gr.Dropdown(
                    choices=art_app.available_styles,
                    label="Art Style",
                    value="Cubism",
                )
                abstraction_slider = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.5, step=0.1,
                    label="Abstraction Level",
                )
                length_slider = gr.Slider(
                    minimum=30, maximum=300, value=90, step=30,
                    label="Video Length (frames)",
                )
                run_button = gr.Button("Generate Abstract Art Video")

            # Right column: rendered result.
            with gr.Column():
                video_output = gr.Video(label="Generated Video")

        # Order must match the positional parameters of process_image.
        callback_inputs = [image_input, style_choice, abstraction_slider, length_slider]
        run_button.click(
            fn=art_app.process_image,
            inputs=callback_inputs,
            outputs=video_output,
        )

    return demo

# Fourth Cell: Main Execution
def main():
    """Build the Gradio interface and serve it.

    ``debug=True`` keeps the call blocking so errors and logs stay visible;
    ``share=True`` requests a public share URL.
    """
    create_gradio_interface().launch(debug=True, share=True)


# Launch the app only when this file is the entry point.
if __name__ == "__main__":
    main()

# Fifth Cell: Download Helper Function
def download_video(file_path):
    """Offer ``file_path`` as a browser download, falling back to a message.

    The download goes through ``files.download`` (Colab helper). If that
    fails for any reason, the failure reason and the location of the file
    are printed instead so the user can still retrieve it manually.

    Args:
        file_path: path of the video file to download.
    """
    try:
        files.download(file_path)
    except Exception as exc:
        # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt
        # and SystemExit still propagate.
        print(f"Download unavailable ({type(exc).__name__}: {exc})")
        print(f"Video saved at: {file_path}")

Using device: cpu
Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://89cd63236161877179.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


Generating frames: 100%|██████████| 90/90 [02:14<00:00,  1.50s/it]
Writing video: 100%|██████████| 90/90 [00:00<00:00, 720.80it/s]
Generating frames:  36%|███▌      | 65/180 [01:33<03:05,  1.61s/it]