In [1]:
# LSGAN

In [2]:
!pip install -q torch torchvision datasets pillow scikit-learn matplotlib pandas tabulate tensorflow seaborn

In [3]:
import os
import numpy as np
from collections import Counter
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader, ConcatDataset, random_split
from torchvision import transforms, utils
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt

# Load dataset and identify indices
raw_ds = load_dataset("yuighj123/covid-19-classification")
class_names = raw_ds['train'].features['label'].names
covid_idx = class_names.index('Covid')
normal_idx = class_names.index('Normal')
viral_idx = class_names.index('Viral Pneumonia')
num_classes = len(class_names)

def map_orig_label(l):
    return l  

# Real data Dataset (all 3 classes)
class RealDataset(Dataset):
    def __init__(self, hf_data, transform=None):
        self.images = hf_data['image']
        self.labels = hf_data['label']
        self.transform = transform
    def __len__(self): return len(self.labels)
    def __getitem__(self, idx):
        img = self.images[idx]
        if not isinstance(img, Image.Image): img = Image.fromarray(np.array(img))
        img = img.convert('RGB')
        if self.transform: img = self.transform(img)
        return img, self.labels[idx]

# Transforms
tfm = transforms.Compose([
    transforms.Resize((64,64)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])
real_ds = RealDataset(raw_ds['train'], transform=tfm)

# Train LSGAN on minority classes only (Normal & Viral)
latent_dim = 100
img_shape = (3,64,64)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define simple LSGAN
class G(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim,128), nn.ReLU(True),
            nn.Linear(128,256), nn.BatchNorm1d(256), nn.ReLU(True),
            nn.Linear(256,512), nn.BatchNorm1d(512), nn.ReLU(True),
            nn.Linear(512,np.prod(img_shape)), nn.Tanh()
        )
    def forward(self,z): return self.net(z).view(-1,*img_shape)

class D(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(np.prod(img_shape),512), nn.LeakyReLU(0.2),
            nn.Linear(512,256), nn.LeakyReLU(0.2),
            nn.Linear(256,1)
        )
    def forward(self,x): return self.net(x.view(x.size(0),-1))

G_model = G().to(device)
D_model = D().to(device)
optG = optim.Adam(G_model.parameters(), lr=2e-4, betas=(0.5,0.999))
optD = optim.Adam(D_model.parameters(), lr=2e-4, betas=(0.5,0.999))
criterion = nn.MSELoss()

def get_minority_loader():
    
    subset = torch.utils.data.Subset(real_ds,
        [i for i,(img,lbl) in enumerate(real_ds) if lbl in [normal_idx, viral_idx]])
    return DataLoader(subset, batch_size=64, shuffle=True)

loader = get_minority_loader()
for epoch in range(10):
    for imgs,l in loader:
        bs=imgs.size(0)
        real = imgs.to(device)
        valid = torch.ones(bs,1,device=device)
        fake_label = torch.zeros(bs,1,device=device)
        # train D
        z = torch.randn(bs,latent_dim,device=device)
        fake = G_model(z)
        lossD = 0.5*(criterion(D_model(real),valid)+criterion(D_model(fake.detach()),fake_label))
        optD.zero_grad(); lossD.backward(); optD.step()
        # train G
        lossG = criterion(D_model(fake),valid)
        optG.zero_grad(); lossG.backward(); optG.step()
    print(f"Epoch {epoch+1}/10 | D:{lossD.item():.4f} G:{lossG.item():.4f}")

#  Generate synthetic for classes 1 &2 to balance all three
counts = Counter([lbl for _,lbl in real_ds])
max_cnt = max(counts.values())
generated = []
for target in [normal_idx, viral_idx]:
    need = max_cnt - counts[target]
    if need<=0: continue
    z = torch.randn(need, latent_dim, device=device)
    with torch.no_grad(): imgs = G_model(z).cpu()
    for img in imgs: generated.append((img, target))

# Save and wrap generated
gen_dir = 'gen_imgs'; os.makedirs(gen_dir, exist_ok=True)
for i,(img,lbl) in enumerate(generated):
    utils.save_image(img, f"{gen_dir}/img_{i}_lbl{lbl}.png", normalize=True)
class SynthDataset(Dataset):
    def __init__(self, folder, tfm):
        self.paths = [os.path.join(folder,f) for f in os.listdir(folder)]
        self.tfm = tfm
    def __len__(self): return len(self.paths)
    def __getitem__(self,idx):
        img = Image.open(self.paths[idx]).convert('RGB')
        label = int(self.paths[idx].split('_lbl')[1].split('.png')[0])
        return self.tfm(img), label
synth_ds = SynthDataset(gen_dir, tfm)

#  Combine real + synthetic, split 70/15/15
full_ds = ConcatDataset([real_ds, synth_ds])
train_n = int(0.7*len(full_ds))
val_n = int(0.15*len(full_ds))
test_n = len(full_ds)-train_n-val_n
train_ds, val_ds, test_ds = random_split(full_ds, [train_n,val_n,test_n])
load = lambda ds: DataLoader(ds, batch_size=64, shuffle=True)
train_ld, val_ld, test_ld = load(train_ds), load(val_ds), load(test_ds)

# Train CNN
class CNN3(nn.Module):
    def __init__(self):
        super().__init__()
        self.feat = nn.Sequential(
            nn.Conv2d(3,32,3,1,1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32,64,3,1,1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64,128,3,1,1), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.clf = nn.Sequential(nn.Flatten(), nn.Linear(128*8*8,128), nn.ReLU(), nn.Linear(128,num_classes))
    def forward(self,x): return self.clf(self.feat(x))

model = CNN3().to(device)
opt = optim.Adam(model.parameters(),lr=1e-3)
crit = nn.CrossEntropyLoss()
for epoch in range(5):
    model.train()
    for x,y in train_ld:
        x,y = x.to(device),y.to(device)
        opt.zero_grad(); crit(model(x),y).backward(); opt.step()
    print(f"Train epoch {epoch+1}/5")
# Evaluate on test
model.eval()
all_p, all_t=[],[]
with torch.no_grad():
    for x,y in test_ld:
        preds = model(x.to(device)).argmax(1).cpu()
        all_p.append(preds); all_t.append(y)
all_p,all_t=torch.cat(all_p),torch.cat(all_t)
print(classification_report(all_t,all_p,labels=[covid_idx,normal_idx,viral_idx],
                             target_names=['COVID-19','Normal','Viral Pneumonia'],digits=4))



Epoch 1/10 | D:0.0763 G:0.8710
Epoch 2/10 | D:0.0212 G:1.1340
Epoch 3/10 | D:0.0148 G:1.1369
Epoch 4/10 | D:0.0118 G:1.3158
Epoch 5/10 | D:0.0179 G:1.6531
Epoch 6/10 | D:0.0794 G:1.2651
Epoch 7/10 | D:0.0216 G:1.3779
Epoch 8/10 | D:0.0069 G:1.4262
Epoch 9/10 | D:0.0058 G:1.3034
Epoch 10/10 | D:0.1250 G:1.2882
Train epoch 1/5
Train epoch 2/5
Train epoch 3/5
Train epoch 4/5
Train epoch 5/5
                 precision    recall  f1-score   support

       COVID-19     0.8571    0.9231    0.8889        13
         Normal     0.9091    0.5882    0.7143        17
Viral Pneumonia     0.7692    0.9524    0.8511        21

       accuracy                         0.8235        51
      macro avg     0.8452    0.8212    0.8181        51
   weighted avg     0.8383    0.8235    0.8151        51

