In [None]:
from datasets import load_dataset
# Load only the "train" split of the 22-24/Final dataset.
dataset = load_dataset(
    "22-24/Final",
    split="train",
)

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
import datasets  # Hugging Face datasets library
import io
from tqdm import tqdm  # Import tqdm

# Define custom dataset class
class ImageSuperResolutionDataset(Dataset):
    """Yield (input, target) pairs derived from the same source image.

    For every index the source image is fetched once and passed through
    ``transform_input`` (to build the model input) and ``transform_target``
    (to build the training target).

    Args:
        hf_dataset: Sized, indexable collection. Each element is either an
            image object or a ``dict`` row holding the image under
            ``image_key``.
        transform_input: Callable applied to the image to produce the input.
        transform_target: Callable applied to the image to produce the target.
        image_key: Key used to extract the image from ``dict`` rows.
    """

    def __init__(self, hf_dataset, transform_input, transform_target, image_key="image"):
        self.dataset = hf_dataset
        self.transform_input = transform_input
        self.transform_target = transform_target
        self.image_key = image_key

    def __len__(self):
        """Return the number of source images."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(input_image, target_image)`` for the image at ``idx``."""
        row = self.dataset[idx]
        # Accept both a plain column of images and full dataset rows.
        image = row[self.image_key] if isinstance(row, dict) else row

        # Force a uniform 3-channel layout: grayscale / palette / alpha images
        # would otherwise yield tensors with differing channel counts, which
        # cannot be stacked into one batch.
        if hasattr(image, "convert") and getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")

        input_image = self.transform_input(image)
        target_image = self.transform_target(image)
        return input_image, target_image

# Custom transform to add JPEG encoding
class JpegEncodingOnly:
    """Transform that round-trips an image through in-memory JPEG compression.

    The image is saved as JPEG into a byte buffer and immediately decoded
    again, so the returned object is a regular image carrying the compression
    artifacts of the chosen quality (it is NOT raw JPEG bytes).

    Args:
        quality: JPEG quality forwarded to ``img.save``. Defaults to 85.
    """

    # Modes the Pillow JPEG writer is expected to accept; any other mode
    # (e.g. RGBA or palette images) is converted to RGB before saving because
    # saving it directly raises an error.
    _JPEG_MODES = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    def __init__(self, quality=85):
        self.quality = quality

    def __call__(self, img):
        """Return a decoded copy of ``img`` after JPEG compression."""
        if img.mode not in self._JPEG_MODES:
            img = img.convert("RGB")

        buffer = io.BytesIO()
        img.save(buffer, format="JPEG", quality=self.quality)
        buffer.seek(0)

        decoded = Image.open(buffer)
        # Image.open is lazy; load() forces the pixel data to be read now.
        decoded.load()
        return decoded

# Input branch: JPEG round-trip (JpegEncodingOnly hands back a decoded image,
# so the following steps operate on a normal image), shrink to 32x32, convert
# to a tensor and normalize with mean 0.5 / std 0.5.
input_steps = [
    JpegEncodingOnly(),
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform_input = transforms.Compose(input_steps)

# Target branch: resize to 128x128, then the same tensor conversion and
# normalization as the input branch.
target_steps = [
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform_target = transforms.Compose(target_steps)

# Wrap the "image" column of the loaded dataset and iterate it in shuffled
# batches of 16.
hf_dataset = dataset["image"]
custom_dataset = ImageSuperResolutionDataset(
    hf_dataset=hf_dataset,
    transform_input=transform_input,
    transform_target=transform_target,
)
dataloader = DataLoader(custom_dataset, batch_size=16, shuffle=True)

# Processor and model are both loaded from the same pretrained checkpoint.
checkpoint = "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"
processor = AutoImageProcessor.from_pretrained(checkpoint)
model = Swin2SRForImageSuperResolution.from_pretrained(checkpoint)

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = model.to(device)

# Pixel-wise mean squared error, optimized with Adam.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

epochs = 1000
for epoch in range(epochs):
    model.train()
    epoch_loss = 0

    # tqdm wraps the loader so every batch advances the progress bar.
    progress = tqdm(dataloader, unit="batch", desc=f"Epoch {epoch+1}/{epochs}")
    with progress as pbar:
        for low_res, high_res in pbar:
            low_res, high_res = low_res.to(device), high_res.to(device)

            # One optimization step: reset gradients, forward, backward, update.
            optimizer.zero_grad()
            prediction = model(pixel_values=low_res).reconstruction
            loss = criterion(prediction, high_res)
            loss.backward()
            optimizer.step()

            # Read the scalar once and reuse it for the running total and the bar.
            batch_loss = loss.item()
            epoch_loss += batch_loss
            pbar.set_postfix(loss=batch_loss)

    # Mean of the per-batch losses for this epoch.
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss/len(dataloader)}")


In [None]:
import torch
from PIL import Image
import matplotlib.pyplot as plt

# Inference function
def infer(input_image_path, model, processor, transform_input, device="cuda",
          output_mean=0.5, output_std=0.5):
    """Run super-resolution on one image file and return the result as an image.

    Args:
        input_image_path: Path of the image to upscale. It is opened and
            converted to RGB.
        model: Callable as ``model(pixel_values=batch)``; the returned object
            must expose the upscaled batch as ``.reconstruction``.
        processor: Not used by this function; kept so existing calls keep
            working.
        transform_input: Callable mapping the opened image to a tensor without
            a batch dimension.
        device: Device the input batch is moved to. If a CUDA device is
            requested but CUDA is unavailable, the CPU is used instead of
            raising.
        output_mean: Mean that was subtracted from the training targets.
        output_std: Std the training targets were divided by. The model output
            is mapped back with ``output * output_std + output_mean`` before
            clamping to [0, 1]. Pass ``output_mean=0, output_std=1`` if the
            targets were not normalized.

    Returns:
        The upscaled image produced by ``transforms.ToPILImage``.
    """
    target_device = torch.device(device)
    if target_device.type == "cuda" and not torch.cuda.is_available():
        # Mirror the training cell's fallback so CPU-only hosts do not crash.
        target_device = torch.device("cpu")

    input_image = Image.open(input_image_path).convert("RGB")
    batch = transform_input(input_image).unsqueeze(0).to(target_device)

    model.eval()
    with torch.no_grad():
        outputs = model(pixel_values=batch)

    # Drop only the batch dimension (a bare squeeze() would also drop any
    # other size-1 dimension).
    output_image = outputs.reconstruction.squeeze(0).cpu()
    # Undo the target normalization before clamping; clamping the normalized
    # values directly would clip everything below the mean to black.
    output_image = (output_image * output_std + output_mean).clamp(0, 1)

    # Convert tensor to PIL Image for visualization
    output_image_pil = transforms.ToPILImage()(output_image)
    return output_image_pil

# Path of the test image to upscale (fill in before running this cell).
input_image_path = " "
output_image = infer(input_image_path, model, processor, transform_input)

# Re-open the original file so it can be shown next to the model output.
input_image = Image.open(input_image_path).convert("RGB")

# One row, two panels: original on the left, model output on the right.
panels = (
    ("Input Image", input_image),
    ("Output Image (Super-Resolution)", output_image),
)
plt.figure(figsize=(10, 5))
for position, (panel_title, picture) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.title(panel_title)
    plt.imshow(picture)
    plt.axis("off")
plt.show()


In [None]:
from transformers import Swin2SRForImageSuperResolution, AutoImageProcessor

# Save model and processor to Hugging Face Hub
def save_model_to_huggingface(model, processor, repo_name="swin2SR-custom-model"):
    """Save the model and processor locally, then push both to the Hub.

    Args:
        model: Object providing ``save_pretrained`` and ``push_to_hub``.
        processor: Object providing ``save_pretrained`` and ``push_to_hub``.
        repo_name: Used both as the local directory (``./<repo_name>``) and as
            the name passed to ``push_to_hub``.

    Raises:
        ValueError: If ``repo_name`` is ``None`` or blank. Checked before
            anything is written so a placeholder name does not leave a stray
            directory behind and only fail at upload time.
    """
    if repo_name is None or not str(repo_name).strip():
        raise ValueError("repo_name must be a non-empty repository name")

    local_dir = f"./{repo_name}"
    model.save_pretrained(local_dir)
    processor.save_pretrained(local_dir)

    model.push_to_hub(repo_name)
    processor.push_to_hub(repo_name)


# Name of the destination repository (fill in before running this cell).
repo_name = " "

# Store the model and processor under that name and upload them.
save_model_to_huggingface(
    model=model,
    processor=processor,
    repo_name=repo_name,
)
