In [1]:
# Required packages, install if not installed (assume PyTorch* and Intel® Extension for PyTorch* is already present)
!echo "Installation in progress..."
# import sys
# !{sys.executable} -m pip install  invisible-watermark > /dev/null
# !conda install -y --quiet --prefix {sys.prefix}  -c conda-forge \
#     accelerate==0.23.0 \
#     validators==0.22.0 \
#     diffusers==0.18.2 \
#     transformers==4.32.1 \
#     tensorboardX \
#     pillow \
#     ipywidgets \
#     ipython > /dev/null && echo "Installation successful" || echo "Installation failed"
import sys
!{sys.executable} -m pip install invisible-watermark --user > /dev/null 2>&1 
!{sys.executable} -m pip install transformers huggingface-hub --user > /dev/null 2>&1
!echo "Installtion complete..."

Installation in progress...
Installtion complete...


In [54]:
from io import BytesIO
import os
import time
import warnings
from pathlib import Path
from typing import List, Dict, Tuple


# Suppress warnings for a cleaner output.
warnings.filterwarnings("ignore")

import random
import requests
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex  # adds xpu namespace to PyTorch, enabling you to use Intel GPUs
import validators
import numpy as np

from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers import DPMSolverMultistepScheduler

import os
import random
import time

import ipywidgets as widgets
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mp_img
import validators

from IPython.display import clear_output
from IPython.display import display
from IPython.display import HTML
from IPython.display import Image as IPImage
from ipywidgets import VBox, HBox

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision import models, transforms
from transformers import AdamW
from torch.nn import CrossEntropyLoss
from PIL import Image
import numpy as np


In [57]:
class Img2ImgModel:
    """
    Text-guided image-to-image generation with a diffusers
    StableDiffusionImg2ImgPipeline, optionally optimized with Intel® Extension
    for PyTorch* (ipex) and run on an Intel GPU ("xpu") by default.
    """

    # Directory searched for a local copy of the model
    # (f"{MODEL_CACHE_DIR}/{model_id_or_path}"); a model downloaded from the Hub
    # is saved there when possible.
    MODEL_CACHE_DIR = "/home/common/data/Big_Data/GenAI"
    # Seed image used by warmup_model().
    DEFAULT_IMAGE_URL = "https://user-images.githubusercontent.com/786476/256401499-f010e3f8-6f8d-4e9f-9d1f-178d3571e7b9.png"
    # (width, height) that get_image_from_url() resizes every image to.
    IMAGE_SIZE = (768, 512)

    def __init__(
        self,
        model_id_or_path: str,
        device: str = "xpu",
        torch_dtype: torch.dtype = torch.bfloat16,
        optimize: bool = True,
        warmup: bool = False,
        scheduler: bool = True,
    ) -> None:
        """
        Load the pipeline and optionally optimize and warm it up.

        Args:
            model_id_or_path (str): Hugging Face model ID (also used as the
                sub-path under MODEL_CACHE_DIR) of the pre-trained model.
            device (str, optional): Device to run the model on. Defaults to "xpu".
            torch_dtype (torch.dtype, optional): Data type for the weights and
                for autocast. Defaults to torch.bfloat16.
            optimize (bool, optional): Run every nn.Module of the pipeline
                through ipex.optimize. Defaults to True.
            warmup (bool, optional): Generate one throw-away image right after
                loading. Defaults to False.
            scheduler (bool, optional): Replace the pipeline's scheduler with
                DPMSolverMultistepScheduler. Defaults to True.
        """
        self.device = device
        self.data_type = torch_dtype
        # Must be set before _load_pipeline(), which reads it.
        self.scheduler = scheduler
        self.generator = torch.Generator()
        self.pipeline = self._load_pipeline(model_id_or_path, torch_dtype)
        if optimize:
            self.optimize_pipeline()
        if warmup:
            self.warmup_model()

    def _load_pipeline(
        self, model_id_or_path: str, torch_dtype: torch.dtype
    ) -> StableDiffusionImg2ImgPipeline:
        """
        Load the pipeline (local cache first, then the Hub) and move it to self.device.

        Args:
            model_id_or_path (str): The ID or path of the pre-trained model.
            torch_dtype (torch.dtype): The data type to use for the model.

        Returns:
            StableDiffusionImg2ImgPipeline: The loaded pipeline.
        """
        print("Loading the model...")
        model_path = Path(f"{self.MODEL_CACHE_DIR}/{model_id_or_path}")
        is_cached = model_path.exists()
        if is_cached:
            load_path = model_path
        else:
            print("Using the default path for models...")
            load_path = model_id_or_path

        pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
            load_path,
            torch_dtype=torch_dtype,
            use_safetensors=True,
            variant="fp16",
        )
        if not is_cached:
            # Save in exactly the form the from_pretrained() call above reads
            # back (safetensors files with "fp16" variant names), and before the
            # scheduler swap so the cache stays a faithful copy of the model.
            # NOTE(review): the weights are stored in torch_dtype (bfloat16 by
            # default) despite the "fp16" variant name.
            try:
                print(f"Attempting to save the model to {model_path}...")
                pipeline.save_pretrained(
                    str(model_path), safe_serialization=True, variant="fp16"
                )
                print("Model saved.")
            except Exception as e:
                print(f"An error occurred while saving the model: {e}. Proceeding without saving.")
        if self.scheduler:
            pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
                pipeline.scheduler.config
            )
        return pipeline.to(self.device)

    def _optimize_pipeline(
        self, pipeline: StableDiffusionImg2ImgPipeline
    ) -> StableDiffusionImg2ImgPipeline:
        """
        Replace every nn.Module attribute of the pipeline with its ipex-optimized
        (eval-mode) version.

        Args:
            pipeline (StableDiffusionImg2ImgPipeline): The pipeline to optimize.

        Returns:
            StableDiffusionImg2ImgPipeline: The same pipeline, optimized.
        """
        # All components are optimized to the text encoder's dtype.
        dtype = pipeline.text_encoder.dtype
        for attr in dir(pipeline):
            component = getattr(pipeline, attr)
            if isinstance(component, nn.Module):
                setattr(
                    pipeline,
                    attr,
                    ipex.optimize(component.eval(), dtype=dtype, inplace=True),
                )
        return pipeline

    def optimize_pipeline(self) -> None:
        """
        Optimize self.pipeline in place with ipex.
        """
        self.pipeline = self._optimize_pipeline(self.pipeline)

    def get_image_from_url(
        self, url: str, path: "str | None" = None, timeout: float = 30.0
    ) -> Image.Image:
        """
        Download an image, optionally save the original, and return it resized.

        Args:
            url (str): The URL of the image.
            path (str, optional): Where to save the downloaded (RGB, not yet
                resized) image. Nothing is saved when None.
            timeout (float, optional): Request timeout in seconds. Defaults to 30.

        Returns:
            Image.Image: The RGB image resized to IMAGE_SIZE.

        Raises:
            Exception: on a non-200 response or a non-image content type.
        """
        response = requests.get(url, timeout=timeout)
        if response.status_code != 200:
            raise Exception(
                f"Failed to download image. Status code: {response.status_code}"
            )
        content_type = response.headers.get("content-type", "")
        if not content_type.startswith("image"):
            raise Exception(
                f"URL does not point to an image. Content type: {content_type}"
            )
        img = Image.open(BytesIO(response.content)).convert("RGB")
        if path is not None:
            img.save(path)
        return img.resize(self.IMAGE_SIZE)

    def warmup_model(self) -> None:
        """
        Warm up the model by generating one sample image into ./.tmp.
        Failures are reported, not raised.
        """
        print("Setting up model...")
        try:
            self.generate_images(
                image_url=self.DEFAULT_IMAGE_URL,
                prompt="A beautiful day",
                num_images=1,
                save_path=".tmp",
            )
        except Exception as e:
            print(f"model warmup delayed... ({e})")

    def get_inputs(self, prompt, batch_size=1):
        """
        Build batched pipeline inputs: the prompt repeated batch_size times and
        one independently seeded torch.Generator per image.

        Returns:
            dict: {"prompt": [prompt] * batch_size, "generator": [...]}; the
            generator list is also stored in self.generator.
        """
        generators = []
        for _ in range(batch_size):
            generator = torch.Generator()
            # A fresh torch.Generator() always starts from the same default
            # seed, which would make every image in the batch identical;
            # seed() draws a non-deterministic seed instead.
            generator.seed()
            generators.append(generator)
        self.generator = generators
        return {"prompt": batch_size * [prompt], "generator": generators}

    @staticmethod
    def _image_filename(prompt, index):
        """File name from the first three prompt words, the image index and a µs timestamp."""
        stem = "_".join(prompt.split()[:3])
        # Prompt words may contain "/" or other path characters (e.g. the
        # "fantasy [,/organic]" enhancement), which would turn the name into a path.
        stem = "".join(c if c.isalnum() or c in "-_" else "_" for c in stem)
        return f"{stem}_{index}__{int(time.time() * 1e6)}.png"

    def _generate(
        self,
        prompt,
        init_image,
        num_images,
        num_inference_steps,
        strength,
        guidance_scale,
        batch_size,
    ):
        """Yield (index, PIL image) for exactly num_images images, batch_size per pipeline call."""
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        autocast_enabled = self.data_type != torch.float32
        for start in range(0, num_images, batch_size):
            # The final batch shrinks so no more than num_images are generated.
            current = min(batch_size, num_images - start)
            with torch.xpu.amp.autocast(enabled=autocast_enabled, dtype=self.data_type):
                if current > 1:
                    inputs = self.get_inputs(batch_size=current, prompt=prompt)
                else:
                    inputs = {"prompt": prompt}
                images = self.pipeline(
                    **inputs,
                    image=[init_image] * current,
                    strength=strength,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                ).images
            for offset, image in enumerate(images):
                yield start + offset, image

    def generate_images(
        self,
        prompt: str,
        image_url: str,
        num_images: int = 5,
        num_inference_steps: int = 30,
        strength: float = 0.75,
        guidance_scale: float = 7.5,
        save_path: str = "image_to_image",
        batch_size: int = 1,
    ):
        """
        Generate images from a seed image URL and save them as PNG files.

        The seed image is also saved to "input.png".

        Args:
            prompt (str): The prompt for the generation.
            image_url (str): The URL of the seed image.
            num_images (int, optional): The number of images to generate. Defaults to 5.
            num_inference_steps (int, optional): Number of noise removal steps. Defaults to 30.
            strength (float, optional): The strength of the transformation. Defaults to 0.75.
            guidance_scale (float, optional): The scale of the guidance. Defaults to 7.5.
            save_path (str, optional): Directory for the generated images (created
                if missing). Defaults to "image_to_image".
            batch_size (int, optional): Images per pipeline call. Defaults to 1.
        """
        init_image = self.get_image_from_url(image_url, "input.png")
        os.makedirs(save_path, exist_ok=True)
        for index, image in self._generate(
            prompt, init_image, num_images, num_inference_steps,
            strength, guidance_scale, batch_size,
        ):
            image.save(os.path.join(save_path, self._image_filename(prompt, index)))

    def generate_image(
        self,
        prompt: str,
        image: Image.Image,
        num_images: int = 5,
        num_inference_steps: int = 30,
        strength: float = 0.75,
        guidance_scale: float = 7.5,
        save_path: str = "image_to_image",
        batch_size: int = 1,
    ):
        """
        Generate images from an in-memory seed image and return them (nothing is saved).

        Args:
            prompt (str): The prompt for the generation.
            image (Image.Image): The seed image.
            num_images (int, optional): The number of images to generate. Defaults to 5.
            num_inference_steps (int, optional): Number of noise removal steps. Defaults to 30.
            strength (float, optional): The strength of the transformation. Defaults to 0.75.
            guidance_scale (float, optional): The scale of the guidance. Defaults to 7.5.
            save_path (str, optional): Unused; kept for signature compatibility
                with generate_images.
            batch_size (int, optional): Images per pipeline call. Defaults to 1.

        Returns:
            list[Image.Image]: The generated images.
        """
        return [
            generated
            for _, generated in self._generate(
                prompt, image, num_images, num_inference_steps,
                strength, guidance_scale, batch_size,
            )
        ]



model_cache = {}  # (model_id, device) -> Img2ImgModel, so each model is loaded only once


def image_to_image():
    """
    Display the interactive image-to-image UI (model, image count, seed-image
    URL, prompt, auto-enhance) and generate images when the button is clicked.

    Generated images are written to ./image_to_image, after which every image
    in that directory is shown via display_generated_images().
    """
    out = widgets.Output()
    image_to_image_dir = "image_to_image"
    default_image_url = "https://user-images.githubusercontent.com/786476/256401499-f010e3f8-6f8d-4e9f-9d1f-178d3571e7b9.png"
    model_ids = [
        "runwayml/stable-diffusion-v1-5",
        "stabilityai/stable-diffusion-2-1",
    ]
    # Style keywords; five are appended at random when "Auto enhance" is ticked.
    enhancements = [
        "purple light",
        "dreaming",
        "cyberpunk",
        "ancient, rustic",
        "gothic",
        "historical",
        "punchy",
        "photo",
        "vivid colors",
        "4k",
        "bright",
        "exquisite",
        "painting",
        "art",
        "fantasy [,/organic]",
        "detailed",
        "trending in artstation fantasy",
        "electric",
        "night",
    ]

    model_dropdown = widgets.Dropdown(
        options=model_ids,
        value=model_ids[0],
        description="Model:",
    )
    prompt_text = widgets.Text(
        value="",
        placeholder="Enter your prompt",
        description="Prompt:",
    )
    num_images_slider = widgets.IntSlider(
        value=2,
        min=1,
        max=10,
        step=1,
        description="Images:",
    )
    image_url_text = widgets.Text(
        value=default_image_url,
        placeholder="Enter an image URL",
        description="Image URL:",
    )
    enhance_checkbox = widgets.Checkbox(
        value=False,
        description="Auto enhance the prompt?",
        disabled=False,
        indent=False,
    )
    enhance_checkbox.layout.margin = "0 0 0 10px"
    num_images_slider.layout.margin = "0 0 0 8px"
    prompt_text.layout.width = "57.5%"
    layout = widgets.Layout(margin="0px 50px 10px 0px")
    button = widgets.Button(description="Generate Images!", button_style="primary")
    button.layout.margin = "35px"
    left_box = VBox([model_dropdown, num_images_slider], layout=layout)
    right_box = VBox([image_url_text, enhance_checkbox], layout=layout)
    user_input_widgets = HBox([left_box, right_box], layout=layout)
    display(user_input_widgets)
    display(prompt_text)
    display(button)
    display(out)

    def on_submit(_clicked_button):
        """Read the widgets, (lazily) load the model, generate and display the images."""
        with out:
            clear_output(wait=True)
            print("Once generated, images will be saved to `./image_to_image` dir, please wait...")
            model_id = model_dropdown.value
            model_key = (model_id, "xpu")
            # The pipeline needs a non-empty prompt.
            prompt = prompt_text.value or " "
            num_images = num_images_slider.value
            image_url = image_url_text.value

            if not validators.url(image_url):
                print("The input is not a valid URL. Using the default URL instead.")
                image_url = default_image_url
            if enhance_checkbox.value:
                prompt = prompt + " " + " ".join(random.sample(enhancements, 5))
                print(f"Using enhanced prompt: {prompt}")
            try:
                # Loading can fail too (download, disk, device), so it lives
                # inside the try and is reported like any generation error.
                if model_key not in model_cache:
                    model_cache[model_key] = Img2ImgModel(model_id, device="xpu")
                model = model_cache[model_key]
                model.generate_images(
                    prompt=prompt,
                    image_url=image_url,
                    num_images=num_images,
                    save_path=image_to_image_dir,
                )
                clear_output(wait=True)
                display_generated_images(image_to_image_dir)
            except KeyboardInterrupt:
                print("\nUser interrupted image generation...")
            except Exception as e:
                print(f"An error occurred: {e}")

    button.on_click(on_submit)

def display_generated_images(image_to_image_dir="image_to_image", figsize=(10, 10)):
    """
    Show every .png/.jpg in a directory in a roughly square grid.

    Args:
        image_to_image_dir (str, optional): Directory to read. Defaults to "image_to_image".
        figsize (tuple, optional): Matplotlib figure size. Defaults to (10, 10).

    Prints a message and returns without plotting when the directory is
    missing or contains no images.
    """
    if not os.path.isdir(image_to_image_dir):
        print(f"No image directory found at {image_to_image_dir!r}.")
        return
    # Sorted so the grid order is stable between runs.
    image_files = sorted(
        f for f in os.listdir(image_to_image_dir) if f.endswith((".png", ".jpg"))
    )
    num_images = len(image_files)
    if num_images == 0:
        print(f"No generated images found in {image_to_image_dir!r}.")
        return

    num_columns = int(np.ceil(np.sqrt(num_images)))
    num_rows = int(np.ceil(num_images / num_columns))
    # squeeze=False always yields a 2-D array of Axes, whatever the grid shape.
    fig, axs = plt.subplots(num_rows, num_columns, figsize=figsize, squeeze=False)
    flat_axes = axs.ravel()
    for ax, image_file in zip(flat_axes, image_files):
        ax.imshow(mp_img.imread(os.path.join(image_to_image_dir, image_file)))
    # Hide the axes on image tiles and on any unused trailing tiles.
    for ax in flat_axes:
        ax.axis("off")
    plt.tight_layout()
    print("\nGenerated images...:")
    plt.show()

In [58]:
# Prompt template for synthesizing extra images of a class; "<INPUT>" is
# replaced by the class name.
BASE_PROMPT = "Image of a <INPUT>"
model_id = "stabilityai/stable-diffusion-2-1"
model = Img2ImgModel(model_id, device="xpu")


def generate_additional_images(image, className, n=1):
    """
    Generate ``n`` image-to-image variations of ``image`` for one class.

    Args:
        image: Seed PIL image.
        className (str): Class name substituted into BASE_PROMPT.
        n (int, optional): Number of images to generate. Defaults to 1.

    Returns:
        list: The generated PIL images (nothing is written to disk).
    """
    prompt = BASE_PROMPT.replace("<INPUT>", className)
    return model.generate_image(prompt, image, num_images=n)


# NOTE(review): the trailing "." after the token looks accidental — confirm the URL.
url = r"https://firebasestorage.googleapis.com/v0/b/hacklytics-fa91a.appspot.com/o/2024-02-10T23%3A17%3A55.324Z?alt=media&token=6143c912-d0e4-4c7f-b882-2391adb653a1."
img = model.get_image_from_url(url, "img.png")
output = generate_additional_images(img, "Satellite Shot of Wildfire", 10)

Loading the model...


  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

In [61]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader, TensorDataset
from PIL import Image
import numpy as np

class VGG16Settings:
    """Hyper-parameters for the VGG16 wrapper, read as attributes by VGG16."""

    # Output size of the replacement classifier; must be set to an int before
    # constructing VGG16 (it is passed to nn.Linear).
    NUM_PREDICTION_CLASSES = None
    TRAIN_EPOCHS = 10  # passes over the training data in VGG16.train_model()
    LR = 0.01  # AdamW learning rate
    BATCH_SIZE = 32  # mini-batch size of the training DataLoader
    DEVICE = "xpu"  # torch device the model and batches are moved to

class BaseModel:
    """
    Intended interface for classifier wrappers such as VGG16 and ViTModel
    (which do not currently subclass it): build from a settings object,
    train on (inputs, targets), classify a single image.
    """

    def __init__(self, settings=None):
        """Store the settings object (may be None)."""
        self.settings = settings

    def train_model(self, inputs, targets):
        """Train on inputs/targets; subclasses must implement this."""
        raise NotImplementedError

    def classify(self, image):
        """Return the predicted class index for one image; subclasses must implement this."""
        raise NotImplementedError
        
class VGG16():
    """
    ImageNet-pretrained VGG16 whose convolutional features are frozen and whose
    classifier is replaced by a small trainable head that outputs
    log-probabilities over ``settings.NUM_PREDICTION_CLASSES`` classes.
    """

    # ImageNet normalization statistics expected by torchvision's pretrained VGG16.
    IMAGENET_MEAN = [0.485, 0.456, 0.406]
    IMAGENET_STD = [0.229, 0.224, 0.225]

    def __init__(self, settings : VGG16Settings):
        """
        Args:
            settings (VGG16Settings): hyper-parameters; NUM_PREDICTION_CLASSES
                must be an int.
        """
        # NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`.
        vgg16 = models.vgg16(pretrained=True)

        # Freeze the convolutional feature extractor; only the new head trains.
        for param in vgg16.features.parameters():
            param.requires_grad = False

        num_features = vgg16.classifier[0].in_features
        vgg16.classifier = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, settings.NUM_PREDICTION_CLASSES),
            nn.LogSoftmax(dim=1),
        )
        self.settings = settings
        self.model = vgg16

    @staticmethod
    def _to_rgb(image):
        """Convert PIL images in other modes (RGBA, L, ...) to RGB; pass anything else through."""
        if isinstance(image, Image.Image) and image.mode != "RGB":
            return image.convert("RGB")
        return image

    def train_model(self, inputs, targets):
        """
        Train the classifier head on labelled images.

        Args:
            inputs: sequence of PIL images.
            targets: integer class labels in [0, NUM_PREDICTION_CLASSES), one per image.

        Returns:
            list[float]: mean training loss of each epoch (also printed).

        Raises:
            ValueError: if inputs is empty or its length differs from targets.
        """
        if len(inputs) == 0:
            raise ValueError("train_model() needs at least one input image")
        if len(inputs) != len(targets):
            raise ValueError(f"got {len(inputs)} inputs but {len(targets)} targets")

        train_transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD),
        ])
        images = [self._to_rgb(image) for image in inputs]
        targets_tensor = torch.tensor(targets, dtype=torch.long)
        # Batch over sample indices and augment inside the loop, so each epoch
        # draws fresh random crops/flips instead of reusing one fixed set.
        train_loader = DataLoader(
            range(len(images)), batch_size=self.settings.BATCH_SIZE, shuffle=True
        )

        device = self.settings.DEVICE
        self.model.to(device)
        # The head already ends in LogSoftmax, so NLLLoss is the matching loss
        # (equivalent to CrossEntropyLoss on these outputs, without re-normalizing).
        criterion = nn.NLLLoss()
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = optim.AdamW(trainable, lr=self.settings.LR)

        self.model.train()
        losses = []
        for _ in range(self.settings.TRAIN_EPOCHS):
            total_loss = 0.0
            for batch_indices in train_loader:
                batch_inputs = torch.stack(
                    [train_transforms(images[i]) for i in batch_indices.tolist()]
                ).to(device)
                batch_targets = targets_tensor[batch_indices].to(device)

                optimizer.zero_grad()
                outputs = self.model(batch_inputs)
                loss = criterion(outputs, batch_targets)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            avg_loss = total_loss / len(train_loader)
            print(avg_loss)
            losses.append(avg_loss)
        return losses

    def classify(self, image):
        """
        Predict the class index of a single image.

        Args:
            image: PIL image.

        Returns:
            int: index of the highest-scoring class.
        """
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD),
        ])
        device = self.settings.DEVICE
        img_tensor = preprocess(self._to_rgb(image)).unsqueeze(0).to(device)  # add batch dimension
        self.model.to(device)

        self.model.eval()
        with torch.no_grad():
            outputs = self.model(img_tensor)
        _, predicted = torch.max(outputs, 1)
        return predicted.item()

In [64]:
import torch
import numpy as np
from PIL import Image

# Assuming the VGG16, VGG16Settings, and BaseModel classes are defined as provided

# Step 1: Define a test case function
def test_vgg16():
    """
    Smoke-test VGG16 end to end on random data: build, train, classify one image.

    Runs on the CPU; all RNGs are seeded so reruns produce the same data and losses.
    """
    np_rng = np.random.default_rng(0)
    torch.manual_seed(0)

    settings = VGG16Settings()
    settings.NUM_PREDICTION_CLASSES = 10  # Example: 10 classes for classification
    settings.DEVICE = "cpu"  # MKL makes this faster than XPU

    model = VGG16(settings=settings)

    # 100 random 224x224 RGB images with random labels.
    num_samples = 100
    dummy_images = [
        Image.fromarray(np.uint8(np_rng.random((224, 224, 3)) * 255))
        for _ in range(num_samples)
    ]
    dummy_labels = np_rng.integers(0, settings.NUM_PREDICTION_CLASSES, num_samples)

    losses = model.train_model(dummy_images, dummy_labels)
    assert len(losses) == settings.TRAIN_EPOCHS, "Training should report one loss per epoch"
    assert np.isfinite(losses).all(), "Training losses should be finite"

    test_image = Image.fromarray(np.uint8(np_rng.random((224, 224, 3)) * 255))
    prediction = model.classify(test_image)
    assert isinstance(prediction, int), "The classification method should return an integer"
    assert 0 <= prediction < settings.NUM_PREDICTION_CLASSES, "The prediction should be within the range of possible classes"
    print("Test passed: VGG16 model instantiation, training, and classification.")


test_vgg16()

14.32515150308609
17.293466091156006
5.30634880065918
3.660771131515503
1.883251503109932
1.6578410863876343
1.5726151764392853
1.4166490137577057
1.3383285403251648
1.2096823304891586
Test passed: VGG16 model instantiation, training, and classification.


In [98]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image
import numpy as np

class ViTSettings:
    """Hyper-parameters for the ViTModel wrapper, read as attributes by ViTModel."""

    NUM_PREDICTION_CLASSES = 1000  # output size of the replacement classification head
    TRAIN_EPOCHS = 10  # passes over the training data in ViTModel.train_model()
    # AdamW learning rate. NOTE(review): 1e-2 is far above the rates usually
    # used to fine-tune a pretrained transformer end to end — confirm.
    LR = 0.01
    BATCH_SIZE = 32  # mini-batch size of the training DataLoader
    DEVICE = "xpu"  # torch device the model and batches are moved to

class ViTModel():
    """
    Image classifier built on the pretrained google/vit-base-patch16-224
    Vision Transformer, with its head replaced by a new linear layer of
    ``settings.NUM_PREDICTION_CLASSES`` outputs. All weights are trained.
    """

    CHECKPOINT = 'google/vit-base-patch16-224'
    # Used only if the image processor exposes no image_mean / image_std.
    FALLBACK_MEAN = [0.485, 0.456, 0.406]
    FALLBACK_STD = [0.229, 0.224, 0.225]

    def __init__(self, settings : ViTSettings):
        """
        Args:
            settings (ViTSettings): hyper-parameters; NUM_PREDICTION_CLASSES sets the head size.
        """
        self.model = ViTForImageClassification.from_pretrained(self.CHECKPOINT)
        self.model.classifier = nn.Linear(
            self.model.config.hidden_size, settings.NUM_PREDICTION_CLASSES
        )
        # ipex.optimize is applied in train_model(), where the optimizer is
        # available; IPEX's training path needs both together.
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(self.CHECKPOINT)
        self.settings = settings

    def _normalize(self):
        """Normalize with the statistics of the checkpoint's own image processor."""
        mean = getattr(self.feature_extractor, "image_mean", None) or self.FALLBACK_MEAN
        std = getattr(self.feature_extractor, "image_std", None) or self.FALLBACK_STD
        return transforms.Normalize(mean=mean, std=std)

    @staticmethod
    def _to_rgb(image):
        """Convert PIL images in other modes (RGBA, L, ...) to RGB; pass anything else through."""
        if isinstance(image, Image.Image) and image.mode != "RGB":
            return image.convert("RGB")
        return image

    def train_model(self, inputs, targets):
        """
        Fine-tune the whole network on labelled images.

        Args:
            inputs: sequence of PIL images.
            targets: integer class labels in [0, NUM_PREDICTION_CLASSES), one per image.

        Returns:
            list[float]: mean training loss of each epoch (also printed).

        Raises:
            ValueError: if inputs is empty or its length differs from targets.
        """
        if len(inputs) == 0:
            raise ValueError("train_model() needs at least one input image")
        if len(inputs) != len(targets):
            raise ValueError(f"got {len(inputs)} inputs but {len(targets)} targets")

        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            self._normalize(),
        ])
        images = [self._to_rgb(image) for image in inputs]
        targets_tensor = torch.tensor(targets, dtype=torch.long)
        # Batch over sample indices and augment inside the loop, so each epoch
        # draws fresh random crops/flips instead of reusing one fixed set.
        train_loader = DataLoader(
            range(len(images)), batch_size=self.settings.BATCH_SIZE, shuffle=True
        )

        device = self.settings.DEVICE
        self.model.to(device)
        self.model.train()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(self.model.parameters(), lr=self.settings.LR)
        # IPEX training API: model (on its device, in train mode) and optimizer
        # are optimized together and both replaced.
        # NOTE(review): a second train_model() call re-optimizes an already
        # ipex-optimized model — confirm IPEX supports that.
        self.model, optimizer = ipex.optimize(self.model, optimizer=optimizer)

        losses = []
        for epoch in range(self.settings.TRAIN_EPOCHS):
            total_loss = 0.0
            for batch_indices in train_loader:
                batch_inputs = torch.stack(
                    [transform(images[i]) for i in batch_indices.tolist()]
                ).to(device)
                batch_targets = targets_tensor[batch_indices].to(device)

                optimizer.zero_grad()
                outputs = self.model(pixel_values=batch_inputs).logits
                loss = criterion(outputs, batch_targets)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            avg_loss = total_loss / len(train_loader)
            print(f"Epoch {epoch+1}, Average Loss: {avg_loss}")
            losses.append(avg_loss)

        return losses

    def classify(self, image):
        """
        Predict the class index of a single image.

        Args:
            image: PIL image.

        Returns:
            int: index of the highest logit.
        """
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            self._normalize(),
        ])
        device = self.settings.DEVICE
        img_tensor = preprocess(self._to_rgb(image)).unsqueeze(0).to(device)  # add batch dimension
        self.model.to(device)

        self.model.eval()
        with torch.no_grad():
            outputs = self.model(pixel_values=img_tensor).logits
        return outputs.argmax(-1).item()


In [99]:
import torch
import numpy as np
from PIL import Image

# Assuming the ViTModel and ViTSettings classes are defined as provided

# Step 1: Define a test case function
def test_vit_model():
    """
    Smoke-test ViTModel end to end on random data: build, train, classify one image.

    Runs on the CPU; all RNGs are seeded so reruns produce the same data and losses.
    """
    np_rng = np.random.default_rng(0)
    torch.manual_seed(0)

    settings = ViTSettings()
    settings.NUM_PREDICTION_CLASSES = 10  # Example: Adjust based on your actual classes
    settings.DEVICE = "cpu"

    model = ViTModel(settings=settings)

    # 100 random 224x224 RGB images with random labels.
    num_samples = 100
    dummy_images = [
        Image.fromarray(np.uint8(np_rng.random((224, 224, 3)) * 255))
        for _ in range(num_samples)
    ]
    dummy_labels = np_rng.integers(0, settings.NUM_PREDICTION_CLASSES, num_samples)

    losses = model.train_model(dummy_images, dummy_labels)
    assert len(losses) == settings.TRAIN_EPOCHS, "Training should report one loss per epoch"
    assert np.isfinite(losses).all(), "Training losses should be finite"

    test_image = Image.fromarray(np.uint8(np_rng.random((224, 224, 3)) * 255))
    predicted_class_idx = model.classify(test_image)
    print(f"Predicted class index: {predicted_class_idx}")
    assert isinstance(predicted_class_idx, int), "The classification method should return an integer"
    assert 0 <= predicted_class_idx < settings.NUM_PREDICTION_CLASSES, "The predicted class index should be within the range of possible classes"

    print("Test passed: ViTModel model instantiation, training, and classification.")


test_vit_model()

Epoch 1, Average Loss: 2.553674876689911
Epoch 2, Average Loss: 2.4934158325195312
Epoch 3, Average Loss: 2.6183581948280334
Epoch 4, Average Loss: 2.5135324597358704
Epoch 5, Average Loss: 2.5134425163269043
Epoch 6, Average Loss: 2.506541073322296
Epoch 7, Average Loss: 2.544667661190033
Epoch 8, Average Loss: 2.463534951210022
Epoch 9, Average Loss: 2.607055187225342
Epoch 10, Average Loss: 2.5501466393470764
Predicted class index: 9
Test passed: CustomViTModel model instantiation, training, and classification.


In [None]:
class ServerOperationHandler():
    """
    Poll the job server for model-training and classification jobs and run
    them locally with the VGG16 ("CNN") or ViTModel wrappers.

    NOTE(review): the JSON layout read below (``data["id"]`` and
    ``data["data"][...]``) comes from the original code — confirm it against
    the server's API.
    """

    def __init__(self, url="http://18.188.69.104:5000", poll_interval=0.5, timeout=30.0):
        """
        Args:
            url (str, optional): Base URL of the job server.
            poll_interval (float, optional): Seconds to sleep between polls.
            timeout (float, optional): Timeout in seconds for every HTTP request.
        """
        self.CONNECTOR = None
        # modelid -> trained wrapper (VGG16 / ViTModel instance).
        self.models = dict()
        # modelid -> category values; position i is what class index i means.
        self.labels = dict()
        self.url = url
        self.poll_interval = poll_interval
        self.timeout = timeout

    def main_loop(self):
        """Poll for training and classification jobs forever (interrupt to stop)."""
        while True:
            for poll in (self.get_compute_job, self.get_classify_jobs):
                # One failing job must not kill the polling loop; report and continue.
                try:
                    poll()
                except Exception as e:
                    print(f"{poll.__name__} failed: {e!r}")
            time.sleep(self.poll_interval)

    def _fetch_image(self, url):
        """Download ``url`` and decode it as an RGB PIL image (raises on HTTP errors)."""
        response = requests.get(url, timeout=self.timeout)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")

    def get_compute_job(self):
        """Fetch a pending training job, if any, and train the requested model type."""
        response = requests.get(self.url + "/checkModels", timeout=self.timeout)
        if response.status_code != 200:
            return
        data = response.json()
        image_data = data["data"]["imageData"]
        images = [self._fetch_image(entry["imageUrl"]) for entry in image_data]
        categories = [entry["category"] for entry in image_data]
        # Compare strings by value; `is` compares object identity.
        if data["data"]["settings"]["modelType"] == "CNN":
            self.process_train_job(VGG16, data["id"], VGG16Settings, images, categories)
        else:
            self.process_train_job(ViTModel, data["id"], ViTSettings, images, categories)

    def get_classify_jobs(self):
        """Fetch a pending classification job, if any, and run it."""
        response = requests.get(self.url + "/checkJobs", timeout=self.timeout)
        if response.status_code != 200:
            return
        data = response.json()
        self.process_classify_job(data["data"]["model"], data["id"], data["data"]["imageUrl"])

    def process_train_job(self, model, modelid, settings, inputs, targets):
        """
        Train a new model on the given images and notify the server.

        Args:
            model: wrapper class to instantiate (VGG16 or ViTModel).
            modelid: server id under which the trained model is stored.
            settings: settings class or instance; a class is instantiated first.
            inputs: training images.
            targets: one category per image (any hashable value, e.g. strings).
        """
        if isinstance(settings, type):
            settings = settings()
        # The wrappers need integer labels 0..K-1; keep the mapping (in
        # first-seen order) so predictions can be translated back.
        class_values = list(dict.fromkeys(targets))
        label_index = {value: i for i, value in enumerate(class_values)}
        settings.NUM_PREDICTION_CLASSES = len(class_values)

        wrapper = model(settings)
        wrapper.train_model(inputs, [label_index[t] for t in targets])
        # Register only after training succeeded.
        self.models[modelid] = wrapper
        self.labels[modelid] = class_values
        requests.post(self.url + '/finishModel', json={"id": modelid}, timeout=self.timeout)

    def process_classify_job(self, modelid, jobid, input):
        """
        Classify one image with a trained model and post the result.

        Args:
            modelid: id of a model trained by process_train_job.
            jobid: server id of the classification job.
            input: image URL (downloaded first) or an already-loaded image.

        Raises:
            KeyError: if no model with ``modelid`` has been trained here.
        """
        if modelid not in self.models:
            raise KeyError(f"No trained model with id {modelid!r}")
        image = self._fetch_image(input) if isinstance(input, str) else input
        index = self.models[modelid].classify(image)
        labels = self.labels.get(modelid)
        # Report the original category value rather than the internal index.
        output = labels[index] if labels is not None else index
        requests.post(self.url + '/finishJob', json={"id": jobid, "output": output}, timeout=self.timeout)