In [None]:
!pip uninstall -y clip
!pip uninstall -y openai-clip
!pip install git+https://github.com/openai/CLIP.git

Found existing installation: clip 1.0
Uninstalling clip-1.0:
  Successfully uninstalled clip-1.0
[0mCollecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-9mzbu5cn
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-9mzbu5cn
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25l[?25hdone
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369490 sha256=f518667fba4b45fdcf6e5bd300f7cf54d38aba4a4f1a8d12477c49421e263079
  Stored in directory: /tmp/pip-ephem-wheel-cache-t5oiyovi/wheels/35/3e/df/3d24cbfb3b6a06f17a2bfd7d1138900d4365d9028aa8f6e92f
Successfully built clip
Installing collected packages: clip
Successfully installed clip-1.0


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import cv2
from pathlib import Path
from PIL import Image
import os
from google.colab import drive
from google.colab import files
import subprocess
import sys
import clip

In [None]:
def mount_google_drive(mount_point='/content/drive', force_remount=False):
    """Mount Google Drive so later cells can read and write files under ``mount_point``.

    Best-effort: a failure is printed rather than raised, but the return value
    now lets callers check whether the mount worked.

    Args:
        mount_point: directory to mount Drive at (default ``'/content/drive'``).
        force_remount: forwarded to ``drive.mount`` (default ``False``).

    Returns:
        True if ``drive.mount`` returned without raising, False otherwise.
    """
    try:
        drive.mount(mount_point, force_remount=force_remount)
    except Exception as e:
        print(f"Error mounting Google Drive: {e}")
        return False
    print("Google Drive mounted successfully!")
    return True

# Call this at the start; later cells load models and data from Drive.
drive_mounted = mount_google_drive()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Google Drive mounted successfully!


In [None]:
class CVAE(nn.Module):
    """Label-conditioned variational autoencoder for single-channel images.

    One learned embedding (2 label classes) is concatenated both to the
    flattened encoder features and to the latent code before decoding.
    Five stride-2 convolutions feed a 512 x 8 x 8 map into ``fc_enc``, so the
    encoder presumably expects 1 x 256 x 256 inputs (TODO confirm); the
    decoder mirrors this and squashes its output with a sigmoid.

    Submodule names must stay exactly as they are because checkpoints are
    loaded by name; creation order also fixes the random-initialisation order.

    Args:
        latent_dim: size of the latent vector ``z``.
        label_emb_dim: size of the label embedding.
    """

    def __init__(self, latent_dim=16, label_emb_dim=16):
        super().__init__()

        self.label_emb = nn.Embedding(2, label_emb_dim)

        # Encoder: enc_conv1 .. enc_conv5, each halving height and width.
        encoder_widths = (1, 32, 64, 128, 256, 512)
        for number, (c_in, c_out) in enumerate(zip(encoder_widths[:-1], encoder_widths[1:]), start=1):
            setattr(self, f"enc_conv{number}", nn.Conv2d(c_in, c_out, 4, 2, 1))

        flat_features = 512 * 8 * 8
        self.fc_enc = nn.Linear(flat_features + label_emb_dim, 512)
        self.fc_mu = nn.Linear(512, latent_dim)
        self.fc_logvar = nn.Linear(512, latent_dim)

        self.fc_dec1 = nn.Linear(latent_dim + label_emb_dim, 512)
        self.fc_dec2 = nn.Linear(512, flat_features)

        # Decoder: dec_deconv1 .. dec_deconv4, each doubling height and width.
        decoder_widths = (512, 256, 128, 64, 32)
        for number, (c_in, c_out) in enumerate(zip(decoder_widths[:-1], decoder_widths[1:]), start=1):
            setattr(self, f"dec_deconv{number}", nn.ConvTranspose2d(c_in, c_out, 4, 2, 1))
        self.dec_out = nn.ConvTranspose2d(32, 1, 4, 2, 1)

    def encode(self, x, y):
        """Return ``(mu, logvar)`` for images ``x`` with labels ``y``; logvar is clamped to [-8, 8]."""
        features = x
        for conv in (self.enc_conv1, self.enc_conv2, self.enc_conv3, self.enc_conv4, self.enc_conv5):
            features = F.relu(conv(features))

        joint = torch.cat([features.view(len(features), -1), self.label_emb(y)], dim=1)
        hidden = F.relu(self.fc_enc(joint))

        # The clamp keeps exp(0.5 * logvar) in a numerically safe range.
        return self.fc_mu(hidden), self.fc_logvar(hidden).clamp(-8, 8)

    def decode(self, z, y):
        """Decode latents ``z`` with labels ``y`` into single-channel sigmoid images."""
        conditioned = torch.cat([z, self.label_emb(y)], dim=1)
        hidden = F.relu(self.fc_dec2(F.relu(self.fc_dec1(conditioned))))

        grid = hidden.view(-1, 512, 8, 8)
        for deconv in (self.dec_deconv1, self.dec_deconv2, self.dec_deconv3, self.dec_deconv4):
            grid = F.relu(deconv(grid))
        return torch.sigmoid(self.dec_out(grid))

    def forward(self, x, y):
        """Encode, sample ``z`` with the reparameterisation trick, decode; return ``(recon, mu, logvar)``."""
        mu, logvar = self.encode(x, y)
        noise = torch.randn_like(mu)
        z = mu + noise * (0.5 * logvar).exp()
        return self.decode(z, y), mu, logvar


class CVAELoader:
    """Load a pre-trained CVAE checkpoint and decode latents with it.

    Args:
        model_path: a directory containing ``.pt`` checkpoint file(s), or the
            path of a checkpoint file itself.
        device: device the model is loaded onto (default ``'cuda'``).
        latent_dim: latent size of the saved model (default 2, the value this
            notebook's checkpoint was loaded with).
        label_emb_dim: label-embedding size of the saved model (default 16).

    Raises:
        FileNotFoundError: if ``model_path`` does not exist, or is a directory
            with no ``.pt`` file in it.
    """

    def __init__(self, model_path, device='cuda', latent_dim=2, label_emb_dim=16):
        self.device = device
        self.latent_dim = latent_dim
        self.label_emb_dim = label_emb_dim
        # Use the caller's path (previously ignored in favour of a hard-coded one).
        self.model = self._load_cvae(model_path)

    @staticmethod
    def _resolve_checkpoint(model_path):
        """Return the checkpoint file: ``model_path`` itself, or the first ``*.pt`` file (by name) inside it."""
        if not model_path.exists():
            raise FileNotFoundError(f"Path does not exist: {model_path}")
        if model_path.is_file():
            return model_path
        # Sorted so the choice is deterministic when several checkpoints exist.
        pt_files = sorted(f for f in model_path.glob('*.pt') if f.is_file())
        if not pt_files:
            raise FileNotFoundError(f"No .pt files found in {model_path}")
        if len(pt_files) > 1:
            print(f"Found {len(pt_files)} .pt files in {model_path}; using the first by name")
        return pt_files[0]

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Unwrap a state dict stored bare or under a 'model' / 'state_dict' key."""
        if isinstance(checkpoint, dict):
            for key in ('model', 'state_dict'):
                if key in checkpoint:
                    return checkpoint[key]
        return checkpoint

    def _load_cvae(self, model_path):
        """Build a CVAE, load the checkpoint found at ``model_path`` and return it in eval mode."""
        pt_file = self._resolve_checkpoint(Path(model_path))
        print(f"Loading CVAE from {pt_file}")

        model = CVAE(latent_dim=self.latent_dim, label_emb_dim=self.label_emb_dim).to(self.device)

        # NOTE: depending on the torch version's weights_only default, torch.load
        # may unpickle arbitrary objects; only load checkpoints from trusted sources.
        checkpoint = torch.load(str(pt_file), map_location=self.device)
        state_dict = self._extract_state_dict(checkpoint)

        try:
            model.load_state_dict(state_dict)
            print("✓ Model loaded with strict=True")
        except RuntimeError as e:
            # Key mismatch: fall back to partial loading, keeping the warning visible.
            print(f"Warning: {e}")
            print("Loading with strict=False")
            model.load_state_dict(state_dict, strict=False)

        model.eval()
        return model

    def decode(self, z, label):
        """Decode latent batch ``z`` (on ``self.device``) to images, all conditioned on the integer ``label``."""
        with torch.no_grad():
            y = torch.full((z.shape[0],), label, dtype=torch.long, device=self.device)
            return self.model.decode(z, y)

In [None]:
class SolarFlareDataset(Dataset):
    """Solar-flare images paired with text captions and flare labels, listed in a CSV.

    Rows whose image file is missing from ``image_dir`` are dropped when the
    dataset is built. Each item is a dict with a CLIP-normalised RGB image
    tensor, a caption string, and the row's ``'Flare label'`` (0 if absent).

    Args:
        csv_path: labels CSV (defaults to the Drive file this notebook used).
        image_dir: directory holding the image files (same kind of default).
        image_size: square side length images are resized to (default 224).
    """

    # Columns tried, in order, for the image filename; the 2nd CSV column is the fallback.
    # The same lookup is used for filtering and for loading so both agree on the file.
    # NOTE(review): confirm which of these columns is authoritative for this CSV.
    FILENAME_COLUMNS = ('SHARP_FILE', 'Image filename')
    DEFAULT_CAPTION = "A solar flare image"

    def __init__(self, csv_path='/content/drive/MyDrive/solar/labels_with_captions.csv',
                 image_dir='/content/drive/MyDrive/solar/preprocessed_hourly_all',
                 image_size=224):
        self.df = pd.read_csv(csv_path)
        self.image_dir = Path(image_dir)
        self.image_size = image_size
        # Row *positions* (used with .iloc in __getitem__), not index labels.
        self.valid_indices = [
            position
            for position, (_, row) in enumerate(self.df.iterrows())
            if self._image_exists(row)
        ]

        print(f"Found {len(self.valid_indices)} valid images out of {len(self.df)}")

        # CLIP's published preprocessing mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711]
            )
        ])

    @classmethod
    def _image_filename(cls, row):
        """Return the first non-empty FILENAME_COLUMNS value of ``row``, else its second column."""
        for column in cls.FILENAME_COLUMNS:
            value = row.get(column)
            if value is not None and not pd.isna(value) and str(value) != '':
                return value
        return row.iloc[1]

    @classmethod
    def _caption(cls, row):
        """Return the row's caption ('caption' column, else first column), or DEFAULT_CAPTION if missing."""
        raw = row['caption'] if 'caption' in row.index else row.iloc[0]
        # Check for NaN before stringifying; str(nan) == 'nan' is also caught.
        if pd.isna(raw) or str(raw) == 'nan':
            return cls.DEFAULT_CAPTION
        return str(raw)

    def _image_exists(self, row):
        """True if the row's image file exists; malformed rows count as missing."""
        try:
            return (self.image_dir / str(self._image_filename(row))).exists()
        except (IndexError, OSError, ValueError):
            return False

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        row = self.df.iloc[self.valid_indices[idx]]
        image_path = self.image_dir / str(self._image_filename(row))

        try:
            # Context manager closes the underlying file handle.
            with Image.open(image_path) as src:
                img = src.convert('RGB')
        except Exception as e:
            print(f"Error loading {image_path}: {e}")
            img = Image.new('RGB', (self.image_size, self.image_size))

        return {
            'image': self.transform(img),
            'caption': self._caption(row),
            'flare_label': row.get('Flare label', 0)
        }


In [None]:
import torch
import torch.nn.functional as F
import clip
import numpy as np
import cv2
from pathlib import Path
from PIL import Image, ImageEnhance

class CLIPGuidedGenerator:
    """Search a CVAE's latent space for an image that CLIP matches to a text prompt.

    The CVAE decoder and CLIP stay fixed; only the latent vector ``z`` is
    optimised, by maximising the cosine similarity between CLIP's embedding of
    the decoded image and CLIP's embedding of the prompt.
    """

    # CLIP's published preprocessing constants; decoder output must be
    # normalised with these before it is fed to encode_image.
    CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
    CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
    CLIP_INPUT_SIZE = 224

    def __init__(self, cvae_model, device='cuda'):
        self.device = device
        self.cvae = cvae_model
        self.latent_dim = cvae_model.fc_mu.out_features
        print(f"CVAE latent dimension detected: {self.latent_dim}")

        print("Loading pre-trained CLIP model (ViT-B/32)...")
        self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=device)
        self.clip_model.eval()

    def _encode_text(self, text_prompt):
        """Return the L2-normalised CLIP text embedding of ``text_prompt`` (no gradient)."""
        with torch.no_grad():
            tokens = clip.tokenize([text_prompt]).to(self.device)
            features = self.clip_model.encode_text(tokens)
            return features / features.norm(dim=-1, keepdim=True)

    def _clip_image_features(self, img):
        """Map decoder output in [0, 1] to L2-normalised CLIP image embeddings (differentiable)."""
        if img.shape[1] == 1:
            img = img.repeat(1, 3, 1, 1)
        size = (self.CLIP_INPUT_SIZE, self.CLIP_INPUT_SIZE)
        img = F.interpolate(img, size=size, mode='bilinear', align_corners=False)
        mean = torch.tensor(self.CLIP_MEAN, dtype=img.dtype, device=img.device).view(1, 3, 1, 1)
        std = torch.tensor(self.CLIP_STD, dtype=img.dtype, device=img.device).view(1, 3, 1, 1)
        features = self.clip_model.encode_image((img - mean) / std)
        return features / features.norm(dim=-1, keepdim=True)

    def generate(self, text_prompt, label=1, steps=300, lr=0.1, clip_weight=1.0,
                 reg_weight=1e-4, seed=None):
        """Generate a solar-flare image for ``text_prompt`` using CLIP guidance.

        Args:
            text_prompt: text the image should match.
            label: CVAE condition label (1 = FL / flare, otherwise NF).
            steps: number of Adam steps on ``z``.
            lr: Adam learning rate.
            clip_weight: weight of the negative cosine-similarity term.
            reg_weight: weight of the L1 penalty on ``z`` (default 1e-4).
            seed: optional seed for the initial ``z``; ``None`` uses the global RNG.

        Returns:
            The decoded image for the ``z`` that achieved the lowest loss.
        """
        label_tensor = torch.tensor([label], dtype=torch.long, device=self.device)

        if seed is None:
            z = torch.randn(1, self.latent_dim, device=self.device, requires_grad=True)
        else:
            rng = torch.Generator(device=self.device).manual_seed(seed)
            z = torch.randn(1, self.latent_dim, device=self.device, generator=rng, requires_grad=True)
        optimizer = torch.optim.Adam([z], lr=lr)

        text_features = self._encode_text(text_prompt)

        best_loss = float('inf')
        best_z = z.detach().clone()

        print(f"Generating image for prompt: '{text_prompt}'")
        print(f"Label: {'FL (Flare)' if label==1 else 'NF (Non-Flare)'}")

        for step in range(steps):
            image_features = self._clip_image_features(self.cvae.decode(z, label_tensor))

            clip_loss = -clip_weight * F.cosine_similarity(text_features, image_features).mean()
            reg_loss = reg_weight * z.abs().mean()
            loss = clip_loss + reg_loss
            loss_value = loss.item()

            # Record z *before* the optimiser step: loss_value was produced by this z.
            if loss_value < best_loss:
                best_loss = loss_value
                best_z = z.detach().clone()

            # Differentiate w.r.t. z only, so no gradients pile up in the CVAE/CLIP weights.
            grad_z, = torch.autograd.grad(loss, [z])
            z.grad = grad_z
            optimizer.step()

            if (step + 1) % 50 == 0:
                print(f"Step {step+1}/{steps}, Loss: {loss_value:.6f}")

        with torch.no_grad():
            final_img = self.cvae.decode(best_z, label_tensor)

        return final_img

    def save_image(self, tensor, save_path):
        """Save an image tensor with values in [0, 1] as an 8-bit RGB file, with contrast/colour boost.

        Accepts (H, W), (C, H, W) or (B, C, H, W) with C in {1, 3}; for batched
        input only the first sample is written. Parent directories are created.
        """
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        img = tensor.detach().cpu()
        if img.dim() == 4:
            img = img[0]
        if img.dim() == 3 and img.shape[0] == 1:
            img = img[0]
        img = img.numpy()
        if img.ndim == 2:
            img = np.stack([img] * 3, axis=0)

        img = np.transpose(img, (1, 2, 0))
        img = (img * 255).clip(0, 255).astype(np.uint8)

        pil_img = Image.fromarray(img)
        pil_img = ImageEnhance.Contrast(pil_img).enhance(2.0)
        # No visible effect on grayscale input (zero saturation).
        pil_img = ImageEnhance.Color(pil_img).enhance(1.5)

        pil_img.save(save_path)
        print(f"✓ Image saved with enhancement to {save_path}")

In [None]:
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    cvae_model_path = "/content/drive/MyDrive/solar/model_b1_2"
    output_dir = Path("/content/drive/MyDrive/solar/clip")
    output_dir.mkdir(parents=True, exist_ok=True)

    try:
        cvae_loader = CVAELoader(cvae_model_path, device=device)
        cvae = cvae_loader.model
        print("✓ CVAE model loaded successfully")
    except Exception as e:
        print(f"Error loading CVAE: {e}")
        # Re-raise rather than sys.exit(1): in a notebook SystemExit hides the
        # traceback, and nothing below can run without the CVAE anyway.
        raise

    generator = CLIPGuidedGenerator(cvae, device=device)
    print("✓ Generator initialized")

    # NOTE(review): this prompt looks like a dataset caption whose flare class was
    # NaN ("Nan solar flare ..."), yet it is generated with label=1 (FL) — confirm intent.
    prompts = [
        'Nan solar flare on 2010-05-01. Active region longitudes -82 to -74. Magnetic flux ratio 0.0'
    ]

    saved_paths = []
    for i, prompt in enumerate(prompts, start=1):
        print(f"\n[{i}/{len(prompts)}] Generating image...")
        try:
            img = generator.generate(prompt, label=1, steps=300, lr=0.1, clip_weight=1.0)
            output_path = output_dir / f"testgenerated_flare_{i}.png"
            generator.save_image(img, output_path)
            saved_paths.append(output_path)
        except Exception as e:
            # Keep going with the remaining prompts; failures are reported below.
            print(f"Error generating image: {e}")

    if len(saved_paths) == len(prompts):
        print("\nGeneration complete! Images saved to Google Drive.")
    else:
        print(f"\nGeneration finished with errors: {len(saved_paths)}/{len(prompts)} images saved to {output_dir}.")

Using device: cpu
Loading CVAE from /content/drive/MyDrive/solar/model_b1_2/model_b1_2.pt
✓ Model loaded with strict=True
✓ CVAE model loaded successfully
CVAE latent dimension detected: 2
Loading pre-trained CLIP model (ViT-B/32)...
✓ Generator initialized

[1/1] Generating image...
Generating image for prompt: 'Nan solar flare on 2010-05-01. Active region longitudes -82 to -74. Magnetic flux ratio 0.0'
Label: FL (Flare)
Step 50/300, Loss: -0.181757
Step 100/300, Loss: -0.181757
Step 150/300, Loss: -0.181758
Step 200/300, Loss: -0.181758
Step 250/300, Loss: -0.181757
Step 300/300, Loss: -0.181757
✓ Image saved with enhancement to /content/drive/MyDrive/solar/clip/testgenerated_flare_1.png

Generation complete! Images saved to Google Drive.
