In [1]:
! pip install torchinfo



# Vision Model

In [2]:
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange

class CNN_branch(nn.Module):
    """Residual depthwise-separable 3-D convolution branch.

    ``forward`` returns ``x + net(x)``. ``net`` is a depthwise 3x3x3 convolution,
    a pointwise 1x1x1 convolution, BatchNorm3d and ReLU6. Kernel 3 with
    padding 1 and stride 1 keeps every spatial/spectral size, so the residual
    sum is shape-compatible.

    Args:
        channels: number of input (and output) channels.
    """
    def __init__(self,channels):
        super().__init__()
        depthwise = nn.Conv3d(channels, channels, kernel_size=3, padding=1,
                              groups=channels)
        pointwise = nn.Conv3d(channels, channels, kernel_size=1)
        self.net = nn.Sequential(depthwise, pointwise,
                                 nn.BatchNorm3d(channels), nn.ReLU6())

    def forward(self,x):
        refined = self.net(x)
        return x + refined

class Attention(nn.Module):
    """Multi-head self-attention over spectral positions with convolutional projections.

    The input ``x`` is ``[b, 1, patch, patch, s]``; the ``Rearrange`` patterns
    and the single input channel of ``conv_project`` fix this layout. A 3x3x1
    convolution produces ``3 * heads`` maps. These are flattened to
    ``[b, s, 3 * heads * patch * patch]`` and split into q, k, v, each of shape
    ``[b, heads, s, patch * patch]``. Before attention, k and v are shortened
    along the spectral axis by a per-head strided Conv2d (stride 4, giving
    ``ceil(s / 4)`` positions). The attended values are folded back to
    ``[b, heads, patch, patch, s]``. They are then projected to one channel and
    LayerNorm-ed over the ``(patch, patch)`` plane, giving ``[b, 1, patch, patch, s]``.

    Args:
        heads: number of attention heads.
        patch: spatial side length of the input patch.
        drop: dropout probability applied after the projections.
    """
    def __init__(self,heads,patch,drop):
        super().__init__()
        self.heads = heads
        self.patch = patch
        # patch**-1 == 1/sqrt(patch * patch): the usual 1/sqrt(head_dim) scaling,
        # since each head's feature vector holds patch * patch values.
        self.scale = patch**-1

        # [b, 1, patch, patch, s] -> [b, s, 3 * heads * patch * patch]
        project_conv = nn.Conv3d(1, 3*heads, kernel_size=(3,3,1),
                                 padding=(1,1,0), bias=False)
        self.conv_project = nn.Sequential(
            project_conv,
            Rearrange('b h x y s -> b s (h x y)'),
            nn.Dropout(drop),
        )

        # Per-head (groups=heads) stride-4 convolutions that shorten k and v along s.
        reduce_kwargs = dict(kernel_size=(3,1), padding=(1,0), stride=(4,1),
                             groups=self.heads, bias=False)
        self.reduce_k = nn.Conv2d(self.heads, self.heads, **reduce_kwargs)
        self.reduce_v = nn.Conv2d(self.heads, self.heads, **reduce_kwargs)

        # heads -> 1 channel, then LayerNorm over the (patch, patch) plane of each spectral slice.
        merge_conv = nn.Conv3d(in_channels=heads, out_channels=1,
                               kernel_size=(3,3,1), padding=(1,1,0), bias=False)
        self.conv_out = nn.Sequential(
            merge_conv,
            nn.Dropout(drop),
            Rearrange('b c x y s-> b c s x y'),
            nn.LayerNorm((patch,patch)),
            Rearrange('b c s x y->b c x y s'),
        )

    def forward(self,x):
        projected = self.conv_project(x)
        query, key, value = (
            rearrange(part, 'b s (h d) -> b h s d', h=self.heads)
            for part in projected.chunk(3, dim=-1)
        )
        key_short = self.reduce_k(key)
        value_short = self.reduce_v(value)

        scores = torch.einsum('bhid,bhjd->bhij', query, key_short) * self.scale
        weights = scores.softmax(dim=-1)
        mixed = torch.einsum('bhij,bhjd->bhid', weights, value_short)

        folded = rearrange(mixed, 'b c s (x y) -> b c x y s ',
                           x=self.patch, y=self.patch)
        return self.conv_out(folded)

class ConvTE(nn.Module):
    """Convolutional transformer encoder block.

    Applies ``Attention`` and then a 3x3x1 convolutional feed-forward
    (Conv3d -> ReLU6 -> Dropout). Each sub-layer is wrapped in a residual
    connection. The input must have a single channel, because the
    feed-forward conv is 1 -> 1.

    Args:
        heads, patch, drop: forwarded to ``Attention``; ``drop`` is also the
            feed-forward dropout probability.
    """
    def __init__(self,heads,patch,drop):
        super().__init__()
        self.attention = Attention(heads,patch,drop)
        ffn_layers = [
            nn.Conv3d(in_channels=1, out_channels=1, kernel_size=(3,3,1),
                      padding=(1,1,0), bias=False),
            nn.ReLU6(),
            nn.Dropout(drop),
        ]
        self.ffn = nn.Sequential(*ffn_layers)

    def forward(self,x):
        attended = x + self.attention(x)
        return attended + self.ffn(attended)

class DBCT(nn.Module):
    """Dual-branch block: a CNN branch and a convolutional-transformer branch.

    The CNN branch is ``CNN_branch`` followed by a depthwise 3x3xband_reduce
    conv. The transformer branch squeezes the channels to 1 with a
    length-preserving 1x1x7 spectral conv, runs ``ConvTE``, and then applies a
    1 -> channels 3x3xband_reduce conv. Both branch convs have no spectral
    padding, so a spectral axis of length ``band_reduce`` collapses to 1. The
    two ``channels``-wide results are concatenated and mixed to ``fc_dim``
    channels by a 1x1x1 conv + BatchNorm3d + ReLU6.

    Args:
        channels: channel width of the input.
        patch, heads, drop: forwarded to ``ConvTE``.
        fc_dim: output channel count.
        band_reduce: spectral length of the input (the collapsing kernel size).
    """
    def __init__(self,channels,patch,heads,drop,fc_dim,band_reduce):
        super().__init__()
        self.cnn_branch = CNN_branch(channels)

        squeeze = nn.Conv3d(channels, 1, kernel_size=(1,1,7),
                            padding=(0,0,3), stride=(1,1,1))
        self.convte_branch = nn.Sequential(squeeze, ConvTE(heads,patch,drop))

        cnn_collapse = nn.Conv3d(channels, channels,
                                 kernel_size=(3,3,band_reduce),
                                 padding=(1,1,0), groups=channels)
        self.cnn_out = nn.Sequential(cnn_collapse, nn.BatchNorm3d(channels),
                                     nn.ReLU6())

        te_collapse = nn.Conv3d(1, channels, kernel_size=(3,3,band_reduce),
                                padding=(1,1,0))
        self.te_out = nn.Sequential(te_collapse, nn.BatchNorm3d(channels),
                                    nn.ReLU6())

        fuse = nn.Conv3d(2*channels, fc_dim, kernel_size=1)
        self.out = nn.Sequential(fuse, nn.BatchNorm3d(fc_dim), nn.ReLU6())

    def forward(self,x):
        branch_cnn, branch_te = self.cnn_branch(x), self.convte_branch(x)
        collapsed = (self.cnn_out(branch_cnn), self.te_out(branch_te))
        return self.out(torch.cat(collapsed, dim=1))

class MSpeFE(nn.Module):
    """Multi-scale spectral feature extraction.

    Splits the channel axis into four groups. Each group goes through a
    depthwise spectral convolution (kernels 1x1x3, 1x1x7, 1x1x11 and 1x1x15,
    each padded to keep the spectral length), followed by BatchNorm3d and
    ReLU6. The four outputs are concatenated back along the channel axis, so
    the output shape equals the input shape.

    Args:
        channels: number of input channels (>= 4). The first three groups get
            ``channels // 4`` channels each. The last group takes the rest, so
            ``channels`` need not be divisible by 4.

    Raises:
        ValueError: if ``channels < 4``.
    """
    def __init__(self,channels):
        super().__init__()
        if channels < 4:
            raise ValueError(f"MSpeFE needs at least 4 channels, got {channels}")
        self.c = channels // 4
        # The last group absorbs the remainder. forward() hands it
        # x[:, 3*c:] (channels - 3*c wide); building it for only c channels
        # crashed whenever channels % 4 != 0.
        last_width = channels - 3 * self.c
        self.spectral1 = self._spectral_branch(self.c, 3)
        self.spectral2 = self._spectral_branch(self.c, 7)
        self.spectral3 = self._spectral_branch(self.c, 11)
        self.spectral4 = self._spectral_branch(last_width, 15)

    @staticmethod
    def _spectral_branch(width, kernel):
        """Depthwise (1, 1, kernel) conv with length-preserving padding, then BatchNorm3d + ReLU6."""
        return nn.Sequential(nn.Conv3d(width, width,
                                       kernel_size=(1,1,kernel),
                                       padding=(0,0,kernel // 2),
                                       groups=width),
                             nn.BatchNorm3d(width),
                             nn.ReLU6()
                             )

    def forward(self,x):
        c = self.c
        x1 = self.spectral1(x[:,0:c,:])
        x2 = self.spectral2(x[:,c:2*c,:])
        x3 = self.spectral3(x[:,2*c:3*c,:])
        x4 = self.spectral4(x[:,3*c:,:])
        return torch.cat((x1,x2,x3,x4),dim=1)

class DBCTNet(nn.Module):
    """DBCTNet hyperspectral feature extractor (stem -> MSpeFE -> DBCT).

    The stem is a 1x1x7 spectral conv with stride 2 and no padding, so the
    spectral length becomes ``band_reduce = (bands - 7) // 2 + 1``. That length
    is then collapsed by DBCT. ``forward`` returns the feature map without a
    classification head.

    Args:
        channels: stem output channels (split into four groups by MSpeFE).
        patch: spatial side length of the input patch.
        bands: number of spectral bands of the input (>= 7).
        fc_dim: channels of the returned feature map.
        heads, drop: forwarded to DBCT's transformer branch.

    Attributes:
        out_dim: ``fc_dim * patch * patch``, the flattened size of ``forward``'s
            output for a ``patch x patch`` input.

    Raises:
        ValueError: if ``bands < 7`` (the stem kernel would not fit).
    """
    def __init__(self,channels=16,patch=9,bands=270,
                 fc_dim=16,heads=2,drop=0.1):
        super().__init__()
        if bands < 7:
            raise ValueError(f"DBCTNet needs at least 7 spectral bands, got {bands}")
        # Spectral length after the stem (kernel 7, stride 2, no padding).
        self.band_reduce = (bands - 7) // 2 + 1
        self.out_dim = fc_dim * patch * patch
        self.stem = nn.Conv3d(1,channels,kernel_size=(1,1,7),
                                            padding=0,stride=(1,1,2))
        self.mspefe = MSpeFE(channels)

        self.dbct = DBCT(channels,patch,heads,drop,fc_dim,self.band_reduce)

    def forward(self,x):
        # x.shape = [batch_size,1,patch_size,patch_size,spectral_bands]
        # Conv3d would also accept an unbatched 4-D tensor, and the channel
        # slicing downstream would then silently misread it, so require 5-D input.
        if x.dim() != 5:
            raise ValueError(
                f"DBCTNet expects a 5-D input [batch, 1, patch, patch, bands], got shape {tuple(x.shape)}")
        x = self.stem(x)
        x = self.mspefe(x)
        return self.dbct(x)

import torch
from torchinfo import summary

# Define the model (200-band input)
model = DBCTNet(bands=200)
# Fall back to CPU so the summary also runs on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

# Dummy batch in the layout DBCTNet.forward expects:
# [batch_size, 1, patch_size, patch_size, spectral_bands]
input_tensor = torch.randn(4, 1, 9, 9, 200, device=device)

# Print model summary (driven by the dummy batch above)
summary(model, input_data=input_tensor, col_names=["input_size", "output_size", "num_params"])

Layer (type:depth-idx)                        Input Shape               Output Shape              Param #
DBCTNet                                       [4, 1, 9, 9, 200]         [4, 16, 9, 9, 1]          --
├─Conv3d: 1-1                                 [4, 1, 9, 9, 200]         [4, 16, 9, 9, 97]         128
├─MSpeFE: 1-2                                 [4, 16, 9, 9, 97]         [4, 16, 9, 9, 97]         --
│    └─Sequential: 2-1                        [4, 4, 9, 9, 97]          [4, 4, 9, 9, 97]          --
│    │    └─Conv3d: 3-1                       [4, 4, 9, 9, 97]          [4, 4, 9, 9, 97]          16
│    │    └─BatchNorm3d: 3-2                  [4, 4, 9, 9, 97]          [4, 4, 9, 9, 97]          8
│    │    └─ReLU6: 3-3                        [4, 4, 9, 9, 97]          [4, 4, 9, 9, 97]          --
│    └─Sequential: 2-2                        [4, 4, 9, 9, 97]          [4, 4, 9, 9, 97]          --
│    │    └─Conv3d: 3-4                       [4, 4, 9, 9, 97]          [4, 4, 9, 9, 9

# Text Model

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizer

# Registry of vision encoders usable by ``Model``: the constructor and the
# flattened size of its output (the input width of ``Model.vision_fc``).
VISION_MODEL_INFO = dict(
    DBCTNet=dict(
        # DBCTNet defaults: fc_dim=16 channels over a 9x9 patch, spectral axis
        # collapsed to 1 (output [b, 16, 9, 9, 1]) -> 16 * 9 * 9 = 1296.
        out_dim=16 * 9 * 9,
        model=DBCTNet,
    ),
)

class Model(nn.Module):
    """Text + hyperspectral-image classifier.

    The text is embedded by a frozen ``bert-base-uncased``: its CLS token plus a
    trainable residual projection. The HSI patch is embedded by the vision
    encoder registered in ``VISION_MODEL_INFO``, flattened and projected to
    BERT's hidden size. The two embeddings are merged and fed to a linear
    classifier.

    Args:
        num_classes: number of output logits.
        vision_encoder_name: key into ``VISION_MODEL_INFO``.
        bands: number of spectral bands passed to the vision encoder.
        merging_method: ``'CONCAT'`` (concatenate), ``'PWA'`` (point-wise add)
            or ``'PWM'`` (point-wise multiply).
        max_length: tokenizer padding/truncation length.

    Raises:
        ValueError: if ``merging_method`` is not one of ``MERGING_METHODS``.
    """

    # Supported ways of fusing the text and vision embeddings.
    MERGING_METHODS = ('CONCAT', 'PWA', 'PWM')

    def __init__(self, num_classes, vision_encoder_name, bands, merging_method, max_length=64):
        super(Model, self).__init__()
        # Validate before downloading BERT. An unknown method used to leave
        # ``self.fc`` undefined and only crash later, inside forward().
        if merging_method not in self.MERGING_METHODS:
            raise ValueError(
                f"merging_method must be one of {self.MERGING_METHODS}, got {merging_method!r}")

        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.max_length = max_length

        # BERT is used as a frozen feature extractor.
        for param in self.bert.parameters():
            param.requires_grad = False

        hidden_size = self.bert.config.hidden_size
        self.bert_fc = nn.Linear(hidden_size, hidden_size)

        vision_info = VISION_MODEL_INFO[vision_encoder_name]
        self.vision_encoder = vision_info['model'](bands=bands)
        self.vision_fc = nn.Linear(vision_info['out_dim'], hidden_size)

        self.merging_method = merging_method
        fused_size = hidden_size * 2 if merging_method == 'CONCAT' else hidden_size
        self.fc = nn.Linear(fused_size, num_classes)

    def train(self, mode=True):
        """Set training mode on all submodules, but keep the frozen BERT in eval mode.

        BERT's weights are frozen and it runs under ``torch.no_grad()``. Letting
        ``model.train()`` switch its dropout on would only add noise to the text
        features.
        """
        super().train(mode)
        self.bert.eval()
        return self

    def tokenize(self, text):
        """Tokenize a string or list of strings to padded ``(input_ids, attention_mask)`` tensors."""
        tokens = self.tokenizer(
            text, padding="max_length", max_length=self.max_length, truncation=True, return_tensors="pt"
        )
        return tokens["input_ids"], tokens["attention_mask"]

    def forward_text_features(self, text):
        """Return CLS embedding + trainable projection of it, shape [batch, hidden_size]."""
        input_ids, attention_mask = self.tokenize(text)
        device = next(self.parameters()).device
        input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)

        with torch.no_grad():
            text_out = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]

        # Residual: frozen CLS features plus a learned correction.
        text_emb = self.bert_fc(text_out)
        return text_out + text_emb

    def forward_vision_features(self, image):
        """Encode ``image`` with the vision encoder, flatten per sample and project to hidden_size."""
        vision_out = self.vision_encoder(image)
        vision_out = vision_out.flatten(1)
        return self.vision_fc(vision_out)

    def forward(self, image, text):
        """Return class logits for a batch of HSI patches and their texts.

        Raises:
            ValueError: if ``self.merging_method`` was changed to an unsupported value.
        """
        text_emb = self.forward_text_features(text)
        vision_emb = self.forward_vision_features(image)

        if self.merging_method == 'CONCAT':
            cls_output = torch.cat((text_emb, vision_emb), dim=1)
        elif self.merging_method == 'PWA':
            cls_output = text_emb + vision_emb
        elif self.merging_method == 'PWM':
            cls_output = text_emb * vision_emb
        else:
            raise ValueError(f"Unsupported merging_method {self.merging_method!r}")

        return self.fc(cls_output)

# Example usage: smoke-test the fused model on one random patch and a caption.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 9
vision_encoder_name = 'DBCTNet'
bands = 270
merging_method='CONCAT'  # 'PWA', 'PWM'

model = Model(
    num_classes=num_classes,
    vision_encoder_name=vision_encoder_name,
    bands=bands,
    merging_method=merging_method
    ).to(device)

# Example text input
text = "This is a sample text for classification."
# Create the dummy image on the model's device. The previous `.cuda()` call
# failed on CPU-only runtimes despite the fallback above.
image = torch.randn(1, 1, 9, 9, bands, device=device)
output = model(image, text)
output.shape

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

torch.Size([1, 9])

# Data Processing Pipeline

In [None]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from sklearn.metrics import f1_score, confusion_matrix, cohen_kappa_score, precision_score, recall_score
import numpy as np
import random
import sys
sys.path.append("/kaggle/input/util-py-1")
from PPreprocess import DatasetPreprocess
import img as util

class HSITextDataset(Dataset):
    """Map-style dataset yielding ``(hsi, text_row, label)`` triples.

    Args:
        hsi_data: indexable HSI patches; its length is the dataset length.
        labels: indexable labels, aligned positionally with ``hsi_data``.
        text_data: pandas DataFrame whose i-th row (by position) is the text
            for sample i.
        max_length: stored for callers; this class does not tokenize.
    """
    def __init__(self, hsi_data, labels, text_data, max_length=64):
        self.hsi_data, self.labels = hsi_data, labels
        self.text_data = text_data
        self.max_length = max_length

    def __len__(self):
        return len(self.hsi_data)

    def __getitem__(self, idx):
        patch = self.hsi_data[idx]
        target = self.labels[idx]
        # The row becomes a plain list, so the DataLoader's default collate
        # yields one sequence per column for a batch.
        row_values = self.text_data.iloc[idx].tolist()
        return patch, row_values, target


# Seed every RNG this pipeline may draw from (data split, model init,
# DataLoader shuffling, dropout) so Restart & Run All is reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

folder_name = '/kaggle/working/Indian Pines'
dataset_preprocess = DatasetPreprocess(folder_name)
data, gt, captions = dataset_preprocess.load_data()

# Split the data (pixels of class 0 are excluded via rem_classes)
train_fraction = 0.10
rem_classes = [0]
(train_rows, train_cols), (test_rows, test_cols) = util.data_split(gt, train_fraction=train_fraction, rem_classes=rem_classes)

text_csv_path = os.path.join(folder_name, 'Indian_pines.csv')
text_df = pd.read_csv(text_csv_path)

# HSITextDataset pairs text rows with patches by position.
# NOTE(review): this relies on split_text_data_based_on_spatial and
# get_patchify_data producing the same sample order — confirm in util/PPreprocess.
train_text, val_text = util.split_text_data_based_on_spatial(text_df, gt, train_rows, train_cols)

(train_input_sub, y_train_sub), (val_input, y_val), (test_input, y_test) = dataset_preprocess.get_patchify_data(patch_size=9)


def _to_model_layout(patches):
    """Convert a patch array to a float32 tensor in DBCTNet's [N, 1, H, W, bands] layout.

    A 4-D input gets a singleton channel axis at dim 1. A 5-D tensor is then
    permuted with (0, 1, 3, 4, 2).
    NOTE(review): the permute implies patches arrive as [N, bands, H, W] — confirm against get_patchify_data.
    """
    tensor = torch.tensor(patches, dtype=torch.float32)
    if tensor.dim() == 4:
        tensor = tensor.unsqueeze(1)
    if tensor.dim() == 5:
        tensor = tensor.permute(0, 1, 3, 4, 2)
    return tensor


train_input_tensor = _to_model_layout(train_input_sub)
val_input_tensor = _to_model_layout(val_input)
train_labels_tensor = torch.tensor(y_train_sub, dtype=torch.long)
val_labels_tensor = torch.tensor(y_val, dtype=torch.long)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = HSITextDataset(
    train_input_tensor,
    train_labels_tensor,
    train_text,
    max_length=64
)

val_dataset = HSITextDataset(
    val_input_tensor,
    val_labels_tensor,
    val_text,
    max_length=64
)

# Pinned memory only helps host->GPU copies; on CPU it just triggers a warning.
pin_memory = device.type == "cuda"

train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
    pin_memory=pin_memory
)

val_loader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=0,
    pin_memory=pin_memory
)

# One logit per remaining class (class 0 is removed via rem_classes above).
num_classes = 16
vision_encoder_name = 'DBCTNet'
bands = 200
merging_method = 'CONCAT'

model = Model(
    num_classes=num_classes,
    vision_encoder_name=vision_encoder_name,
    bands=bands,
    merging_method=merging_method
).to(device)

criterion = nn.CrossEntropyLoss()
# BERT is frozen inside Model (requires_grad=False), so give Adam only the trainable parameters.
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=0.001)

def train_model(model, train_loader, criterion, optimizer, device):
    """Run one training epoch and return the mean per-batch loss.

    Each batch is ``(hsi, text, labels)``. ``model`` is called as
    ``model(hsi, text[1])``. Labels are cast to float before ``criterion``,
    so they are presumably one-hot / class-probability rows.

    Args:
        model: module called as ``model(hsi, text)``.
        train_loader: iterable of ``(hsi, text, labels)`` batches.
        criterion: loss function ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s trainable parameters.
        device: device the ``hsi`` and ``labels`` tensors are moved to.

    Returns:
        float: the average of ``loss.item()`` over all batches.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    progress_bar = tqdm(train_loader, desc="Training", leave=False)
    for num_batches, (hsi, text, labels) in enumerate(progress_bar, start=1):
        hsi, labels = hsi.to(device), labels.to(device)
        optimizer.zero_grad()
        # NOTE(review): text[1] presumably selects the caption column of the collated text rows — confirm against the dataset.
        outputs = model(hsi, text[1])
        labels = labels.float()
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # tqdm only refreshes progress_bar.n periodically (miniters/mininterval),
        # so dividing by it gives a wrong running mean; use our own batch count.
        progress_bar.set_postfix({'loss': total_loss / num_batches})
    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")
    return total_loss / num_batches

def evaluate_model(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without gradients.

    Labels are cast to float for ``criterion`` and argmax-ed along dim 1 to get
    class indices, so each label row is expected to be one-hot / class
    probabilities.

    Returns:
        tuple: (mean batch loss, accuracy, weighted F1, weighted precision,
        weighted recall, confusion matrix, per-class accuracies, true class
        indices, predicted class indices). A per-class accuracy is NaN for a
        class that appears only among the predictions, because its
        confusion-matrix row is empty.

    Raises:
        ValueError: if ``val_loader`` has no batches.
    """
    if len(val_loader) == 0:
        raise ValueError("val_loader has no batches to evaluate")
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc="Evaluating", leave=False)
        for step, (hsi, text, labels) in enumerate(progress_bar, start=1):
            hsi, labels = hsi.to(device), labels.to(device)
            # NOTE(review): text[1] presumably selects the caption column of the collated text rows — confirm against the dataset.
            outputs = model(hsi, text[1])
            labels = labels.float()
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            labels_new = torch.argmax(labels, dim=1)
            correct += (predicted == labels_new).sum().item()
            all_labels.extend(labels_new.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())
            # tqdm refreshes progress_bar.n only periodically; use our own step count.
            progress_bar.set_postfix({'val_loss': total_loss / step})
    accuracy = correct / total

    # zero_division=0 is the value scikit-learn's default ("warn") already
    # uses for undefined scores; this only silences the warning.
    f1 = f1_score(all_labels, all_predictions, average='weighted', zero_division=0)
    precision = precision_score(all_labels, all_predictions, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_predictions, average='weighted', zero_division=0)

    # Calculate confusion matrix
    conf_matrix = confusion_matrix(all_labels, all_predictions)

    # Class-wise accuracies; an empty row gives 0/0 -> NaN (warning suppressed).
    with np.errstate(divide='ignore', invalid='ignore'):
        class_accuracies = conf_matrix.diagonal() / conf_matrix.sum(axis=1)

    return total_loss / len(val_loader), accuracy, f1, precision, recall, conf_matrix, class_accuracies, all_labels, all_predictions

def train(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=25,
          log_path="training_log.txt", checkpoint_path='best_model_.pth'):
    """Train for ``num_epochs`` epochs, logging metrics and checkpointing the best model.

    Each epoch runs ``train_model`` and then ``evaluate_model``. A one-line
    summary is written to ``log_path`` (truncated at the start) and printed.
    ``model.state_dict()`` is saved to ``checkpoint_path`` whenever the
    validation accuracy beats the best seen so far; the first epoch always saves.

    Args:
        num_epochs: number of epochs to run.
        log_path: text file receiving one summary line per epoch.
        checkpoint_path: where the best ``state_dict`` is saved.

    Returns:
        list[dict]: one dict per epoch with the metrics returned by ``evaluate_model``.
    """
    best_accuracy = float('-inf')
    results = []

    # ``with`` closes the log even if an epoch raises (it used to leak).
    with open(log_path, "w") as log_file:
        for epoch in range(num_epochs):
            train_loss = train_model(model, train_loader, criterion, optimizer, device)
            (val_loss, val_accuracy, val_f1, precision, recall, conf_matrix,
             class_accuracies, all_labels, all_predictions) = evaluate_model(model, val_loader, criterion, device)
            results.append({
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'val_accuracy': val_accuracy,
                'val_f1': val_f1,
                'precision': precision,
                'recall': recall,
                'conf_matrix': conf_matrix,
                'class_accuracies': class_accuracies,
                'all_labels': all_labels,
                'all_predictions': all_predictions
            })

            # Log and print the results
            log_entry = (f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
                         f"Val Accuracy: {val_accuracy:.4f}, Val F1 Score: {val_f1:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}\n")
            log_file.write(log_entry)
            # Flush so the log is readable while a long run is still going.
            log_file.flush()
            print(log_entry)

            if val_accuracy > best_accuracy:
                best_accuracy = val_accuracy
                torch.save(model.state_dict(), checkpoint_path)

    return results

# Train the model and get results
results = train(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=25)

# Extract the last epoch's results
# NOTE(review): these are the final-epoch metrics, not those of the best checkpoint saved by train().
last_epoch_results = results[-1]
conf_matrix = last_epoch_results['conf_matrix']
class_accuracies = last_epoch_results['class_accuracies']
all_labels = last_epoch_results['all_labels']
all_predictions = last_epoch_results['all_predictions']

# Calculate overall accuracy
overall_accuracy = last_epoch_results['val_accuracy']

# Average accuracy: mean per-class accuracy. A class whose confusion-matrix
# row is empty has a NaN accuracy (0/0). nanmean skips it, whereas np.mean
# would turn the whole average into NaN.
average_accuracy = np.nanmean(class_accuracies)

# Calculate Kappa Coefficient
kappa_score = cohen_kappa_score(all_labels, all_predictions)

# Print the results
print(f"Overall Accuracy: {overall_accuracy * 100:.4f}%")
print(f"Average Accuracy: {average_accuracy * 100:.4f}%")
print(f"Kappa Coefficient: {kappa_score * 100:.4f}")

# Print class-wise accuracies
print("Class-wise Accuracies:")
for class_idx, accuracy in enumerate(class_accuracies, start=1):
    print(f"Class {class_idx}: {accuracy * 100:.4f}%")


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

Data loaded.
CSV file loaded.
Text data splitted based on spatial split.
Using patch size: 9
Patching is done.


Evaluating:  16%|█▌        | 371/2298 [00:06<00:33, 57.17it/s, val_loss=0.0587]

# Prediction & Results

In [None]:
import sys
sys.path.append("/kaggle/input/datapipeline")

import PPreprocess
from PPreprocess import DatasetPreprocess
import img as util

import numpy as np
# Load the data
folder_name = '/kaggle/working/Indian Pines'
dataset_preprocess = DatasetPreprocess(folder_name)
data, gt, captions = dataset_preprocess.load_data()

# Split the data
# NOTE(review): if util.data_split is randomized, this split may differ from the
# one used for training above — confirm it is deterministic.
train_fraction = 0.10
rem_classes = [0]  # Classes to exclude
(train_rows, train_cols), (test_rows, test_cols) = util.data_split(gt, train_fraction=train_fraction, rem_classes=rem_classes)

# Get the labels for training and testing sets
train_labels = gt[train_rows, train_cols]
test_labels = gt[test_rows, test_cols]


def _print_class_counts(header, labels):
    """Print ``header``, then one "Class <id>: <n> samples" line per class present in ``labels``."""
    classes, counts = np.unique(labels, return_counts=True)
    print(header)
    for cls, count in zip(classes, counts):
        print(f"Class {cls}: {count} samples")


# Print class-wise sample counts for training and testing
_print_class_counts("Class-wise sample counts for training:", train_labels)
_print_class_counts("\nClass-wise sample counts for testing:", test_labels)
