In [1]:
# Pin the visible GPU *before* importing torch or any project module, so the
# setting still takes effect if one of those imports initialises CUDA.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

import sys
from types import SimpleNamespace

import numpy as np
import scipy
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image, ImageOps
from scipy.ndimage import gaussian_filter
from tqdm import tqdm

# Make the repository root importable for the project-local packages below.
sys.path.append('..')
# NOTE(review): this MMC is shadowed by the MMC class redefined later in this notebook.
from model.model import MMC
from src.config import Config


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
args = SimpleNamespace(
    name='MMC',
    dataset='Food101',  # Adjust as per your dataset or requirement
    text_type='abstract',
    mmc='UniSMMC',
    mmc_tao=0.07,
    batch_size=32,
    lr_mm=0.001,
    min_epoch=1,
    valid_step=50,
    max_length=512,
    text_encoder='bert_base',
    image_encoder='vit_base',
    text_out=768,
    img_out=768,
    lr_mm_cls=0.001,
    mm_dropout=0.0,
    lr_text_tfm=2e-5,
    lr_img_tfm=5e-5,
    lr_img_cls=0.0001,
    lr_text_cls=5e-5,
    text_dropout=0.0,
    img_dropout=0.1,
    nplot='',
    data_dir='../datasets/',  # Ensure this path is correct in your notebook environment
    test_only=False,
    pretrained_dir='../pretrained_models',  # Adjust as necessary
    model_save_dir='Path/To/results/models',
    res_save_dir='Path/To/results/results',
    fig_save_dir='Path/To/results/imgs',
    logs_dir='Path/To/results/logs',
    local_rank=-1,
    seeds=None,
    model_path='./Path/To/results/models',
    save_model=True,
    cross_attention=False,
    text_mixup=False,
    image_mixup=False,
    image_embedding_mixup=False,
    alpha=0.2,
    multi_mixup=True,
    mixup_pct=0.33,
    lambda_mixup=0.1,
    mixup_beta=0.15,
    mixup_s_thresh=0.5,
    lr_scheduler='ReduceLROnPlateau',
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
)

# Expand the raw namespace into the project's full configuration.
config = Config(args)
args = config.get_config()

# Seed NumPy and torch (CPU and all CUDA devices). args.seed is presumably
# filled in by Config (the namespace above only sets seeds=None) — TODO confirm.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

if args.local_rank == -1:
    # Single-process run: fall back to CPU when no GPU is visible instead of
    # unconditionally selecting "cuda", which would fail on CPU-only hosts.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    # Distributed run: bind this process to its local GPU and join the NCCL group.
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    torch.distributed.init_process_group(backend="nccl")

args.device = device
print(args.data_dir)
args.data_dir = os.path.join(args.data_dir, args.dataset)

../datasets/


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from ptflops import get_model_complexity_info
from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModel

class TextEncoder(nn.Module):
    """Hugging Face text backbone bundled with its matching tokenizer.

    ``forward`` takes an already-tokenised triple and returns the backbone's
    ``last_hidden_state``.
    """

    def __init__(self, model_name='bert-base-uncased'):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.text_encoder = AutoModel.from_pretrained(model_name)

    def get_tokenizer(self):
        """Return the tokenizer loaded alongside the backbone."""
        return self.tokenizer

    def forward(self, inputs):
        """Encode a tokenised batch.

        Args:
            inputs: ``(input_ids, token_type_ids, attention_mask)`` triple.

        Returns:
            The backbone output's ``last_hidden_state``.
        """
        ids, segment_ids, mask = inputs
        encoded = self.text_encoder(
            input_ids=ids,
            token_type_ids=segment_ids,
            attention_mask=mask,
        )
        return encoded.last_hidden_state

class ImageEncoder(nn.Module):
    """Hugging Face vision backbone bundled with its feature extractor."""

    def __init__(self, model_name='google/vit-base-patch16-224'):
        super().__init__()
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
        self.image_encoder = AutoModel.from_pretrained(model_name)

    def get_tokenizer(self):
        """Return the image feature extractor (same method name as TextEncoder.get_tokenizer)."""
        return self.feature_extractor

    def forward(self, pixel_values):
        """Run the backbone on ``pixel_values`` and return its ``last_hidden_state``."""
        return self.image_encoder(pixel_values=pixel_values).last_hidden_state

class MMC(nn.Module):
    """Late-fusion image + text classifier defined here for complexity measurement.

    NOTE: this local definition shadows the ``MMC`` imported from ``model.model``.

    Attributes read from ``args``:
        text_model_name: Hugging Face id of the text backbone.
        image_model_name: Hugging Face id of the image backbone.
        num_classes: optional number of output classes; defaults to 2 when absent.
    """

    def __init__(self, args):
        super().__init__()
        self.text_encoder = TextEncoder(args.text_model_name)
        self.image_encoder = ImageEncoder(args.image_model_name)
        # Previously hard-coded to 2; read it from args when provided
        # (presumably 101 for Food101) so the head matches the target dataset.
        num_classes = getattr(args, 'num_classes', 2)
        fused_dim = (self.text_encoder.text_encoder.config.hidden_size
                     + self.image_encoder.image_encoder.config.hidden_size)
        self.classifier = nn.Linear(fused_dim, num_classes)

    def forward(self, image, text, target=None):
        """Classify an (image, text) pair from the first token of each encoder.

        Args:
            image: pixel tensor forwarded to ``ImageEncoder``.
            text: ``(input_ids, token_type_ids, attention_mask)`` triple for ``TextEncoder``.
            target: optional class-index tensor; when given, a cross-entropy loss is computed.

        Returns:
            ``logits`` when ``target`` is None, otherwise ``(logits, loss)``.
        """
        # First token of each sequence (CLS token) as the modality summary.
        image_features = self.image_encoder(image)[:, 0, :]
        text_features = self.text_encoder(text)[:, 0, :]

        # Concatenate image then text features along the feature axis.
        combined_features = torch.cat([image_features, text_features], dim=1)
        output = self.classifier(combined_features)

        if target is not None:
            loss = F.cross_entropy(output, target)
            return output, loss
        return output

class Args:
    """Minimal namespace for the complexity-measurement cell.

    Provides the two backbone ids that ``MMC.__init__`` reads, plus a
    ``device`` field (not read by ``MMC``; the model is moved to CPU explicitly).
    """

    def __init__(self):
        # Hugging Face hub identifiers for the two backbone encoders.
        self.text_model_name = 'bert-base-uncased'
        self.image_model_name = 'google/vit-base-patch16-224'
        self.device = 'cpu'


# NOTE(review): rebinding `args` discards the Config-derived namespace built earlier.
args = Args()

# Load your multimodal model.
# NOTE(review): absolute, machine-specific path — adjust for your environment.
model_path = '/raid/nlp/rajak/Multimodal/UniS-MMC/Path/To/results/models/Food101-best-m3col.pth'
model = MMC(args)
# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
state_dict = torch.load(model_path, map_location='cpu')
# strict=False skips mismatched keys without complaint; report them so a checkpoint
# whose keys do not match this MMC definition does not go unnoticed.
load_result = model.load_state_dict(state_dict, strict=False)
print(f"Missing keys: {len(load_result.missing_keys)}, "
      f"unexpected keys: {len(load_result.unexpected_keys)}")
model = model.cpu().eval()

# Define input resolutions for both modalities
image_res = (3, 224, 224)  # shape of the dummy image (without batch dim)
text_res = 25              # padded token length of the dummy text

# Custom input constructor for multimodal input
def input_constructor(input_res):
    """Build one dummy multimodal batch for ``ptflops.get_model_complexity_info``.

    Args:
        input_res: ``(image_shape, seq_len)`` — the tuple passed to
            ``get_model_complexity_info``; ``image_shape`` is the per-sample image
            shape (e.g. ``(3, 224, 224)``) and ``seq_len`` the padded token length.

    Returns:
        Keyword arguments for ``MMC.forward``: ``image`` with a leading batch dim of 1,
        ``text`` as an ``(input_ids, token_type_ids, attention_mask)`` triple, and a
        fixed ``target``.
    """
    # Honour the requested resolutions instead of silently reading module globals.
    image_shape, seq_len = input_res
    image_input = torch.ones((1, *image_shape), dtype=torch.float32)

    tokenizer = model.text_encoder.get_tokenizer()
    encoded_input = tokenizer(
        "Sample text for tokenization",
        return_tensors="pt",
        max_length=seq_len,
        padding='max_length',
        truncation=True,
    )
    input_ids = encoded_input['input_ids']
    attention_mask = encoded_input['attention_mask']
    # Fall back to all-zero segment ids if the tokenizer does not return them.
    token_type_ids = encoded_input.get('token_type_ids', torch.zeros_like(input_ids))
    text_input = (input_ids, token_type_ids, attention_mask)

    # Fixed class-0 target (valid for any number of classes) instead of a random
    # label, so the loss branch of MMC.forward runs deterministically.
    target = torch.zeros(1, dtype=torch.long)

    return {'image': image_input, 'text': text_input, 'target': target}

# Measure MACs and parameter counts with ptflops. Failures are reported with a
# full traceback rather than stopping the notebook.
try:
    flops, params = get_model_complexity_info(
        model,
        (image_res, text_res),
        as_strings=True,
        verbose=True,
        print_per_layer_stat=True,
        input_constructor=input_constructor,
    )
    print(f"FLOPs: {flops}")
    print(f"Params: {params}")
except Exception as err:
    print(f"An error occurred: {str(err)}")
    import traceback
    traceback.print_exc()

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


MMC(
  171.89 M, 87.754% Params, 19.0 GMac, 100.000% MACs, 
  (text_encoder): TextEncoder(
    85.65 M, 43.725% Params, 2.13 GMac, 11.195% MACs, 
    (text_encoder): BertModel(
      85.65 M, 43.725% Params, 2.13 GMac, 11.195% MACs, 
      (embeddings): BertEmbeddings(
        1.54 k, 0.001% Params, 19.2 KMac, 0.000% MACs, 
        (word_embeddings): Embedding(0, 0.000% Params, 0.0 Mac, 0.000% MACs, 30522, 768, padding_idx=0)
        (position_embeddings): Embedding(0, 0.000% Params, 0.0 Mac, 0.000% MACs, 512, 768)
        (token_type_embeddings): Embedding(0, 0.000% Params, 0.0 Mac, 0.000% MACs, 2, 768)
        (LayerNorm): LayerNorm(1.54 k, 0.001% Params, 19.2 KMac, 0.000% MACs, (768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(0, 0.000% Params, 0.0 Mac, 0.000% MACs, p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        85.05 M, 43.423% Params, 2.13 GMac, 11.192% MACs, 
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            7.09 