In [1]:
import os
import torch
import torch.nn as nn
from PIL import Image
from torchvision.utils import save_image
import wandb
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import Resize, Normalize, Compose
from transformers import AutoModel, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Ask PyTorch's CUDA caching allocator to release memory it holds but is not
# currently using, so a re-run of the notebook starts from a cleaner GPU state.
torch.cuda.empty_cache()

In [3]:
class RiceDataset(Dataset):
    """Image/caption dataset read from a folder that has one sub-folder per class.

    Every file found in ``root_dir/<class_dir>/`` becomes one sample. The
    caption is derived from the class folder name, and is tokenized with the
    ``bert-base-uncased`` tokenizer when the sample is fetched.

    Args:
        root_dir: Directory whose sub-directories are the class folders.
        transform: Optional callable applied to the RGB PIL image. When
            omitted, images are resized to 256x256, converted to a tensor and
            normalized with mean 0.5 / std 0.5 per channel.
    """

    # Captions for the known disease folders; any other folder name falls
    # back to the healthy caption.
    _CAPTIONS = {
        "_BrownSpot": "Rice plant with brown spot disease",
        "_Hispa": "Rice plant with Hispa disease",
        "_LeafBlast": "Rice plant with Leaf Blast disease",
    }
    _DEFAULT_CAPTION = "Healthy Rice Plant"

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform or Compose([
            Resize((256, 256)),
            transforms.ToTensor(),
            Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.images, self.captions = self.load_dataset()
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    def load_dataset(self):
        """Walk ``root_dir`` and return parallel lists ``(image_paths, captions)``.

        Entries are visited in sorted order so the sample index is the same on
        every run and platform. Stray files directly under ``root_dir`` and
        nested directories inside a class folder are skipped instead of
        raising or being treated as images.
        """
        images = []
        captions = []
        for class_dir in sorted(os.listdir(self.root_dir)):
            class_path = os.path.join(self.root_dir, class_dir)
            if not os.path.isdir(class_path):
                # e.g. a README or OS metadata file next to the class folders
                continue
            for image_name in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, image_name)
                if not os.path.isfile(img_path):
                    continue
                images.append(img_path)
                captions.append(self.get_caption(class_dir, image_name))
        return images, captions

    def get_caption(self, class_dir, image_name):
        """Return the caption for a class folder.

        ``image_name`` is accepted for interface compatibility but does not
        influence the caption.
        """
        return self._CAPTIONS.get(class_dir, self._DEFAULT_CAPTION)

    def __len__(self):
        """Number of samples discovered by :meth:`load_dataset`."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, input_ids, attention_mask)`` for sample ``idx``.

        ``input_ids`` comes from ``tokenizer.encode(..., return_tensors='pt')``
        padded/truncated to 128 tokens; ``attention_mask`` is a float tensor of
        the same shape that is 1.0 wherever the token is not the pad token.
        """
        # Use the context manager so the underlying file handle is closed as
        # soon as the pixel data has been converted, instead of waiting for GC.
        with Image.open(self.images[idx]) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        caption = self.captions[idx]
        input_ids = self.tokenizer.encode(caption, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
        attention_mask = (input_ids != self.tokenizer.pad_token_id).float()
        return img, input_ids, attention_mask

In [4]:
class TextEncoder(nn.Module):
    """Wraps a pretrained Hugging Face model and returns one vector per sequence.

    The vector is the last hidden state at sequence position 0 (the leading
    special-token position for BERT-style inputs).
    """

    def __init__(self, model_name='bert-base-uncased'):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)

    @staticmethod
    def _drop_extra_axis(tensor):
        """Collapse axis 1 of a 3-D tensor; leave any other rank untouched."""
        return tensor.squeeze(1) if tensor.dim() == 3 else tensor

    def forward(self, input_ids, attention_mask):
        """Encode token ids and return ``last_hidden_state[:, 0, :]``.

        Both inputs may carry an extra axis at position 1 (as produced when
        per-sample ``(1, seq)`` tensors are batched); it is squeezed away
        before the wrapped model is called.
        """
        token_ids = self._drop_extra_axis(input_ids)
        mask = self._drop_extra_axis(attention_mask)

        hidden_states = self.model(token_ids, attention_mask=mask).last_hidden_state
        return hidden_states[:, 0, :]

In [5]:
class Generator(nn.Module):
    """Conditional generator: (noise, text embedding) -> 3x256x256 image in [-1, 1].

    A linear layer projects the concatenated inputs to a 1024x8x8 feature map,
    which five nearest-neighbour upsampling stages (8 -> 256) turn into an
    image; the final activation is ``Tanh``.

    Args:
        latent_dim: Length of the noise vector.
        text_dim: Length of the text embedding.
        bn_eps: ``eps`` used by the BatchNorm layers inside the upsampling
            blocks. The original code wrote ``nn.BatchNorm2d(c, 0.8)``; the
            second positional argument of ``BatchNorm2d`` is ``eps`` (not
            momentum), so those layers ran with ``eps=0.8``. The default keeps
            that value so existing checkpoints behave identically; pass
            ``bn_eps=1e-5`` (the PyTorch default) when training from scratch.
    """

    def __init__(self, latent_dim, text_dim, bn_eps=0.8):
        super().__init__()
        self.latent_dim = latent_dim
        self.text_dim = text_dim
        self.init_size = 8  # Initial size of feature maps (256 = 8 * 2^5)
        self.l1 = nn.Linear(latent_dim + text_dim, 1024 * self.init_size ** 2)

        # Layers are appended in exactly the original order so the indices
        # inside the Sequential (and therefore the state_dict keys) are unchanged.
        layers = [nn.BatchNorm2d(1024)]
        for in_ch, out_ch in ((1024, 512), (512, 256), (256, 128), (128, 64)):
            layers += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
                nn.BatchNorm2d(out_ch, eps=bn_eps),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers += [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.conv_blocks = nn.Sequential(*layers)

    def forward(self, noise, text):
        """Generate images.

        Args:
            noise: Tensor of shape ``(batch, latent_dim)``.
            text: Tensor of shape ``(batch, text_dim)``.

        Returns:
            The output of ``conv_blocks`` applied to the projected
            ``(batch, 1024, init_size, init_size)`` feature map.
        """
        input_tensor = torch.cat([noise, text], dim=1)
        out = self.l1(input_tensor)
        out = out.view(out.shape[0], 1024, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

In [6]:
class Discriminator(nn.Module):
    """Text-conditioned discriminator returning a real/fake probability.

    The image goes through six stride-2 convolutions, the caption through its
    own :class:`TextEncoder`; both feature vectors are concatenated and passed
    to an MLP that ends in ``Sigmoid`` (so the output pairs with ``BCELoss``).

    Args:
        text_dim: Length of the vector produced by the text encoder.
    """

    def __init__(self, text_dim):
        super().__init__()
        self.text_encoder = TextEncoder()

        self.image_model = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.Conv2d(512, 1024, 4, 2, 1),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
        )

        # Calculate the output size of the image_model with a dummy 256x256 image.
        # The probe runs in eval mode: in training mode BatchNorm would fold this
        # all-zero image into its running statistics (torch.no_grad() does not
        # prevent that) and Dropout2d would consume random numbers.
        self.image_model.eval()
        with torch.no_grad():
            dummy_input = torch.zeros(1, 3, 256, 256)
            dummy_output = self.image_model(dummy_input)
            self.image_features_size = dummy_output.view(1, -1).size(1)
        self.image_model.train()

        print(f"Image features size: {self.image_features_size}")

        self.fusion_layer = nn.Sequential(
            nn.Linear(self.image_features_size + text_dim, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.4),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.4),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

    def forward(self, image, input_ids, attention_mask):
        """Score ``image`` against the caption given by ``input_ids``.

        Args:
            image: Image batch; the fusion layer is sized for 256x256 inputs.
            input_ids: Token ids, 2-D or 3-D with a squeezable axis 1.
            attention_mask: Mask matching ``input_ids``.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in (0, 1).
        """
        if input_ids.dim() == 3:
            input_ids = input_ids.squeeze(1)
        if attention_mask.dim() == 3:
            attention_mask = attention_mask.squeeze(1)

        text_features = self.text_encoder(input_ids, attention_mask)

        image_features = self.image_model(image)
        image_features = image_features.view(image_features.size(0), -1)

        fused_features = torch.cat([image_features, text_features], dim=1)

        validity = self.fusion_layer(fused_features)
        return validity

Hyperparameters

In [7]:
latent_dim = 100  # length of the random noise vector fed to the generator
text_dim = 768  # length of the text embedding (hidden size of the BERT encoder used here)
num_epochs = 500  # number of passes over the training loader
batch_size = 64  # samples per DataLoader batch
learning_rate = 0.0002  # Adam step size, shared by both optimizers
beta1 = 0.5  # first Adam beta coefficient (the second is fixed at 0.999)


In [8]:
# Prefer the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# nn.Module.to() moves the parameters and returns the module itself, so the
# construction and the device transfer can be chained.
generator = Generator(latent_dim, text_dim).to(device)
discriminator = Discriminator(text_dim).to(device)

# Last expression of the cell: displays the discriminator architecture.
discriminator

Image features size: 16384


Discriminator(
  (text_encoder): TextEncoder(
    (model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSdpaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=T

In [9]:
# One Adam optimizer per network so each is updated only from its own loss.
# NOTE(review): discriminator.parameters() also yields the weights of the
# Discriminator's text_encoder, so the BERT encoder is trained with this same
# learning rate — confirm that fine-tuning it here is intended.
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(beta1, 0.999))
g_optimizer = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(beta1, 0.999))


# Binary cross-entropy on probabilities; the Discriminator already ends in nn.Sigmoid.
criterion = nn.BCELoss()


In [10]:
from torchvision import transforms

# Explicit preprocessing pipeline that can be passed as RiceDataset(..., transform=transform).
# It must agree with the models: the Discriminator's fusion layer is sized for
# 256x256 inputs, and the Generator ends in Tanh, so real images have to live in
# the same [-1, 1] range (mean 0.5 / std 0.5) for the comparison to be fair.
# The previous 224x224 + ImageNet statistics would have caused a shape mismatch
# in the Discriminator and a value-range mismatch between real and fake images.
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # match the resolution produced by the Generator
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # [0, 1] -> [-1, 1]
])

In [12]:
# Location of the class folders. Set the RICE_DATA_DIR environment variable to
# run the notebook on another machine; the original path is kept as the default.
DATA_ROOT = os.environ.get(
    "RICE_DATA_DIR",
    r"D:\GuruGobindSinghIndrapasthaUniversity-Jatin\harjot_data\rice_images\rice_images",
)
dataset = RiceDataset(DATA_ROOT)
# Shuffled mini-batches of (image, input_ids, attention_mask)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [11]:
def discriminator_train_step(batch_size, discriminator, generator, d_optimizer, criterion, real_images, input_ids, attention_mask):
    """Run one optimizer step on the discriminator and return its loss.

    The discriminator is asked to score ``real_images`` as 1 and freshly
    generated images (conditioned on the same captions) as 0; the loss is the
    mean of the two criterion values.

    Args:
        batch_size: Number of fake images to generate.
        discriminator: Model exposing ``text_encoder`` and callable as
            ``discriminator(image, input_ids, attention_mask)``.
        generator: Model callable as ``generator(noise, text_embedding)``.
        d_optimizer: Optimizer holding the discriminator parameters.
        criterion: Loss taking ``(prediction, target)``.
        real_images: Batch of real images, already on the training device.
        input_ids: Token ids, 2-D or 3-D with a squeezable axis 1.
        attention_mask: Mask matching ``input_ids``.

    Returns:
        The discriminator loss as a Python float.
    """
    d_optimizer.zero_grad()

    # Ensure input_ids and attention_mask are 2D
    if input_ids.dim() == 3:
        input_ids = input_ids.squeeze(1)
    if attention_mask.dim() == 3:
        attention_mask = attention_mask.squeeze(1)

    # Noise size and device are taken from the arguments rather than from
    # notebook-level globals, so the function does not depend on cell order.
    noise_dim = getattr(generator, "latent_dim", None)
    if noise_dim is None:
        noise_dim = latent_dim  # fall back to the notebook-level setting
    target_device = real_images.device

    # Train with real images
    real_validity = discriminator(real_images, input_ids, attention_mask)
    d_real_loss = criterion(real_validity, torch.ones_like(real_validity))

    # Train with fake images. The generator must not receive gradients in this
    # step, so the fakes are produced without recording an autograd graph
    # (previously the graph was built and then thrown away with .detach()).
    with torch.no_grad():
        z = torch.randn(batch_size, noise_dim, device=target_device)
        text_embedding = discriminator.text_encoder(input_ids, attention_mask)
        fake_images = generator(z, text_embedding)
    fake_validity = discriminator(fake_images, input_ids, attention_mask)
    d_fake_loss = criterion(fake_validity, torch.zeros_like(fake_validity))

    # Total discriminator loss
    d_loss = (d_real_loss + d_fake_loss) / 2
    d_loss.backward()
    d_optimizer.step()

    return d_loss.item()

def generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion, input_ids, attention_mask):
    """Run one optimizer step on the generator and return its loss.

    Fake images are generated for the given captions and the generator is
    rewarded when the discriminator scores them as real (target 1).

    Args:
        batch_size: Number of fake images to generate.
        discriminator: Model exposing ``text_encoder`` and callable as
            ``discriminator(image, input_ids, attention_mask)``.
        generator: Model callable as ``generator(noise, text_embedding)``.
        g_optimizer: Optimizer holding the generator parameters.
        criterion: Loss taking ``(prediction, target)``.
        input_ids: Token ids, already on the training device; 2-D or 3-D with
            a squeezable axis 1.
        attention_mask: Mask matching ``input_ids``.

    Returns:
        The generator loss as a Python float.
    """
    g_optimizer.zero_grad()

    # Ensure input_ids and attention_mask are 2D
    if input_ids.dim() == 3:
        input_ids = input_ids.squeeze(1)
    if attention_mask.dim() == 3:
        attention_mask = attention_mask.squeeze(1)

    # Noise size and device are taken from the arguments rather than from
    # notebook-level globals, so the function does not depend on cell order.
    noise_dim = getattr(generator, "latent_dim", None)
    if noise_dim is None:
        noise_dim = latent_dim  # fall back to the notebook-level setting
    target_device = input_ids.device

    # Generate fake images. The text encoder belongs to the discriminator and
    # is not updated by g_optimizer, so its embedding is computed without
    # recording a graph; the generator's own gradients are unaffected.
    z = torch.randn(batch_size, noise_dim, device=target_device)
    with torch.no_grad():
        text_embedding = discriminator.text_encoder(input_ids, attention_mask)
    fake_images = generator(z, text_embedding)

    # Try to fool the discriminator
    validity = discriminator(fake_images, input_ids, attention_mask)
    g_loss = criterion(validity, torch.ones_like(validity))

    g_loss.backward()
    g_optimizer.step()

    return g_loss.item()

In [12]:
import os

# Make sure both output folders are present; existing folders are left as they are.
for _folder in ("saved_models_1", "generated_images_1"):
    os.makedirs(_folder, exist_ok=True)

In [13]:
import csv
import os
from torchvision.utils import save_image

# Make sure the folders used for sample images and checkpoints are present;
# existing folders are left as they are.
for _folder in ("generated_images", "saved_models"):
    os.makedirs(_folder, exist_ok=True)

def save_models(epoch, generator, discriminator, save_dir="saved_models"):
    """Write the state dicts of both networks for ``epoch``.

    Files are named ``generator_epoch_{epoch}.pth`` and
    ``discriminator_epoch_{epoch}.pth`` inside ``save_dir``.

    Args:
        epoch: Epoch number used in the file names.
        generator: Module whose ``state_dict()`` is saved.
        discriminator: Module whose ``state_dict()`` is saved.
        save_dir: Target folder; created (including parents) when missing so
            the function does not depend on an earlier cell having run.
    """
    os.makedirs(save_dir, exist_ok=True)
    torch.save(generator.state_dict(), os.path.join(save_dir, f"generator_epoch_{epoch}.pth"))
    torch.save(discriminator.state_dict(), os.path.join(save_dir, f"discriminator_epoch_{epoch}.pth"))
    print(f"Models saved for epoch {epoch}")

def log_metrics(epoch, d_loss, g_loss, log_path='training_log.csv'):
    """Append one ``epoch, d_loss, g_loss`` row to a CSV log.

    Args:
        epoch: Epoch number written in the first column.
        d_loss: Discriminator loss for that epoch.
        g_loss: Generator loss for that epoch.
        log_path: CSV file to append to; created when it does not exist.
            No header row is written.
    """
    # newline='' lets the csv module control line endings (avoids blank rows on Windows)
    with open(log_path, 'a', newline='') as file:
        csv.writer(file).writerow([epoch, d_loss, g_loss])


In [None]:

# Main adversarial training loop: for every batch, one discriminator update
# followed by one generator update; per-epoch averages are logged to CSV.
for epoch in range(num_epochs):
    # Per-batch losses of this epoch, averaged once the epoch is finished
    d_losses = []
    g_losses = []

    for batch_idx, (real_images, input_ids, attention_mask) in enumerate(train_loader):
        real_images, input_ids, attention_mask = real_images.to(device), input_ids.to(device), attention_mask.to(device)

        # The actual batch length is passed, so a shorter final batch gets a matching number of fakes
        d_loss = discriminator_train_step(real_images.size(0), discriminator, generator, d_optimizer, criterion, real_images, input_ids, attention_mask)
        g_loss = generator_train_step(real_images.size(0), discriminator, generator, g_optimizer, criterion, input_ids, attention_mask)

        d_losses.append(d_loss)
        g_losses.append(g_loss)

        # Progress line on every second batch
        if (batch_idx + 1) % 2== 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(train_loader)}], D_loss: {d_loss:.4f}, G_loss: {g_loss:.4f}")

    avg_d_loss = sum(d_losses) / len(d_losses)
    avg_g_loss = sum(g_losses) / len(g_losses)
    log_metrics(epoch + 1, avg_d_loss, avg_g_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}], Avg D_loss: {avg_d_loss:.4f}, Avg G_loss: {avg_g_loss:.4f}")

    # Generate and save a grid of sample images every 20 epochs
    # NOTE(review): neither the generator nor the text encoder is switched to
    # eval() here, so sampling runs in whatever mode the modules are currently in.
    if (epoch + 1) % 20 == 0:
        with torch.no_grad():

            batch_size = 64  # number of samples in the grid; this rebinds the notebook-level batch_size
            z = torch.randn(batch_size, latent_dim, device=device)

            # Create a batch of the same text
            text = ["A healthy rice plant"] * batch_size
            input_ids = dataset.tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt').input_ids.to(device)
            attention_mask = (input_ids != dataset.tokenizer.pad_token_id).float().to(device)

            text_embedding = discriminator.text_encoder(input_ids, attention_mask)
            generated_images = generator(z, text_embedding)
            save_image(generated_images, f"generated_images/epoch_{epoch+1}.png", nrow=8, normalize=True)
            print(f"Generated images saved as generated_images/epoch_{epoch+1}.png")

    # Save models every 50 epochs
    if (epoch + 1) % 50  == 0:
        save_models(epoch + 1, generator, discriminator)

print("Training completed.")


In [None]:
def generate_image_from_text(text, filename):
    """Generate one image for ``text`` and save it as ``generated_images/{filename}.png``.

    Uses the notebook-level ``generator``, ``discriminator``, ``dataset``
    (for its tokenizer), ``device`` and ``latent_dim``.

    Both the generator and the text encoder are put in eval mode for the
    forward pass (so BatchNorm uses its running statistics and dropout is
    disabled), and their previous training/eval modes are restored afterwards.
    Previously the generator was left in eval mode for good, which silently
    affected any training cell executed after this one, and the text encoder
    was run in whatever mode it happened to be in.

    Args:
        text: Caption to condition the generator on.
        filename: Output file name without extension.
    """
    text_encoder = discriminator.text_encoder
    generator_was_training = generator.training
    encoder_was_training = text_encoder.training

    generator.eval()
    text_encoder.eval()
    try:
        with torch.no_grad():
            input_ids = dataset.tokenizer.encode(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt').to(device)
            attention_mask = (input_ids != dataset.tokenizer.pad_token_id).float().to(device)
            z = torch.randn(1, latent_dim, device=device)
            text_embedding = text_encoder(input_ids, attention_mask)
            generated_image = generator(z, text_embedding)
            save_image(generated_image, f"generated_images/{filename}.png", normalize=True)
    finally:
        # Put both modules back the way they were, even if generation failed
        generator.train(generator_was_training)
        text_encoder.train(encoder_was_training)

# Example usage: writes generated_images/healthy_rice_plant.png for the given caption
generate_image_from_text("A healthy rice plant", "healthy_rice_plant")

In [15]:
def load_models(generator, discriminator, epoch, load_dir="saved_models", map_location=None):
    """Load the checkpoints written by ``save_models`` for ``epoch`` into both networks.

    Args:
        generator: Module that receives ``generator_epoch_{epoch}.pth``.
        discriminator: Module that receives ``discriminator_epoch_{epoch}.pth``.
        epoch: Epoch number used in the file names.
        load_dir: Folder containing the checkpoint files.
        map_location: Forwarded to ``torch.load``. When ``None``, tensors are
            mapped onto the device each model currently lives on, so a
            checkpoint saved on a GPU can be loaded on a CPU-only machine.

    Returns:
        The same ``(generator, discriminator)`` objects, updated in place.
    """
    def _read_state(model, file_name):
        location = map_location
        if location is None:
            first_param = next(iter(model.parameters()), None)
            if first_param is not None:
                location = first_param.device
        # weights_only=True restricts unpickling to tensors and plain
        # containers, which is all a state dict holds.
        return torch.load(os.path.join(load_dir, file_name), map_location=location, weights_only=True)

    generator.load_state_dict(_read_state(generator, f"generator_epoch_{epoch}.pth"))
    discriminator.load_state_dict(_read_state(discriminator, f"discriminator_epoch_{epoch}.pth"))
    print(f"Models loaded from epoch {epoch}")
    return generator, discriminator

In [16]:
# Load the last saved models
last_epoch = 300 # Replace with the epoch number of your last saved model
generator, discriminator = load_models(generator, discriminator, last_epoch)

# Reset optimizers: fresh Adam instances are created, so the moment estimates
# from the earlier run are not carried over (only the model weights are restored).
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(beta1, 0.999))
g_optimizer = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(beta1, 0.999))

# Continue training for another num_epochs epochs, numbering them after last_epoch
for epoch in range(last_epoch, last_epoch + num_epochs):
    # Per-batch losses of this epoch, averaged once the epoch is finished
    d_losses = []
    g_losses = []

    for batch_idx, (real_images, input_ids, attention_mask) in enumerate(train_loader):
        real_images, input_ids, attention_mask = real_images.to(device), input_ids.to(device), attention_mask.to(device)

        # The actual batch length is passed, so a shorter final batch gets a matching number of fakes
        d_loss = discriminator_train_step(real_images.size(0), discriminator, generator, d_optimizer, criterion, real_images, input_ids, attention_mask)
        g_loss = generator_train_step(real_images.size(0), discriminator, generator, g_optimizer, criterion, input_ids, attention_mask)

        d_losses.append(d_loss)
        g_losses.append(g_loss)

        # Progress line on every second batch
        if (batch_idx + 1) % 2 == 0:
            print(f"Epoch [{epoch+1}/{last_epoch + num_epochs}], Step [{batch_idx+1}/{len(train_loader)}], D_loss: {d_loss:.4f}, G_loss: {g_loss:.4f}")

    avg_d_loss = sum(d_losses) / len(d_losses)
    avg_g_loss = sum(g_losses) / len(g_losses)
    log_metrics(epoch + 1, avg_d_loss, avg_g_loss)
    print(f"Epoch [{epoch+1}/{last_epoch + num_epochs}], Avg D_loss: {avg_d_loss:.4f}, Avg G_loss: {avg_g_loss:.4f}")

    # Generate and save images after every 20 epochs
    # NOTE(review): neither the generator nor the text encoder is switched to
    # eval() here, so sampling runs in whatever mode the modules are currently in.
    if (epoch + 1) % 20 == 0:
        with torch.no_grad():
            batch_size = 64  # number of samples in the grid; this rebinds the notebook-level batch_size
            z = torch.randn(batch_size, latent_dim, device=device)
            # Same caption repeated for every sample in the grid
            text = ["A healthy rice plant"] * batch_size
            input_ids = dataset.tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt').input_ids.to(device)
            attention_mask = (input_ids != dataset.tokenizer.pad_token_id).float().to(device)
            text_embedding = discriminator.text_encoder(input_ids, attention_mask)
            generated_images = generator(z, text_embedding)
            save_image(generated_images, f"generated_images/epoch_{epoch+1}.png", nrow=8, normalize=True)
            print(f"Generated images saved as generated_images/epoch_{epoch+1}.png")

    # Save models every 50 epochs
    if (epoch + 1) % 50 == 0:
        save_models(epoch + 1, generator, discriminator)

print("Training completed.")

Models loaded from epoch 300
Epoch [301/800], Step [2/53], D_loss: 0.0032, G_loss: 6.1806
Epoch [301/800], Step [4/53], D_loss: 0.0116, G_loss: 6.0997
Epoch [301/800], Step [6/53], D_loss: 0.0758, G_loss: 6.7526
Epoch [301/800], Step [8/53], D_loss: 0.3016, G_loss: 8.6380
Epoch [301/800], Step [10/53], D_loss: 0.1457, G_loss: 8.1033
Epoch [301/800], Step [12/53], D_loss: 0.0762, G_loss: 7.1117
Epoch [301/800], Step [14/53], D_loss: 0.0200, G_loss: 6.8384
Epoch [301/800], Step [16/53], D_loss: 0.0085, G_loss: 8.7462
Epoch [301/800], Step [18/53], D_loss: 0.0246, G_loss: 6.0262
Epoch [301/800], Step [20/53], D_loss: 0.0467, G_loss: 9.0268
Epoch [301/800], Step [22/53], D_loss: 0.0012, G_loss: 9.3924
Epoch [301/800], Step [24/53], D_loss: 0.0020, G_loss: 8.3505
Epoch [301/800], Step [26/53], D_loss: 0.0224, G_loss: 7.4477
Epoch [301/800], Step [28/53], D_loss: 0.0225, G_loss: 12.0038
Epoch [301/800], Step [30/53], D_loss: 0.0111, G_loss: 1.8972
Epoch [301/800], Step [32/53], D_loss: 0.164