In [None]:
# Mount Google Drive; the results directory used later lives under /content/drive/MyDrive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
import numpy
print("Numpy version imported successfully",numpy.__version__)

import transformers
from transformers import BlipProcessor, BlipForConditionalGeneration
print("Transformers version imported successfully", transformers.__version__)

from PIL import Image
from datasets import load_dataset
print("Datasets loaded")

import easyocr
print("EasyOCR version:", easyocr.__version__)
# No EasyOCR reader is built here: the OCR cell below creates the shared `reader`.
# Building one in this cell too loaded the OCR models an extra time for nothing.

Numpy version imported successfully 1.26.0
Transformers version imported successfully 4.40.0
Datasets loaded




EasyOCR version: 1.7.2
Progress: |██████████████████████████████████████████████████| 100.0% Complete



Progress: |██████████████████████████████████████████████████| 100.0% Complete

In [None]:
import pandas
import os
from tqdm import tqdm
from PIL import Image
import torch
import cv2
import json
from sklearn.cluster import DBSCAN
import numpy as np
from transformers import pipeline

In [None]:
from datasets import load_dataset
import itertools

# ComicsPAP "sequence_filling" task, validation split.
# streaming=True yields an IterableDataset that is read lazily while iterating.
DATASET_REPO = "VLR-CVC/ComicsPAP"
DATASET_TASK = "sequence_filling"
dataset = load_dataset(DATASET_REPO, DATASET_TASK, split="val", streaming=True)

# OCR + text clustering + correction:

In [None]:
def preprocess_image(img):
    """Convert a panel image into a 3-channel array in BGR channel order.

    BGR is the channel order OpenCV uses, and the order EasyOCR assumes for
    3-channel numpy input.

    Args:
        img: a PIL image (any mode), or a numpy array. The array may be
            (H, W) grayscale, (H, W, 1), (H, W, 3) RGB or (H, W, 4) RGBA.

    Returns:
        A C-contiguous (H, W, 3) array with the channels reversed (RGB -> BGR).
        The dtype of the input is kept.
    """
    # PIL palette / grayscale / RGBA / 16-bit modes would otherwise give a
    # 2-D or 4-channel array that a plain RGB->BGR conversion rejects.
    if hasattr(img, "convert"):
        img = img.convert("RGB")
    img_np = np.asarray(img)
    if img_np.ndim == 2:
        img_np = np.stack([img_np] * 3, axis=-1)   # grayscale -> 3 identical channels
    elif img_np.ndim == 3 and img_np.shape[2] == 1:
        img_np = np.repeat(img_np, 3, axis=2)      # single channel -> 3 channels
    elif img_np.ndim == 3 and img_np.shape[2] == 4:
        img_np = img_np[..., :3]                   # drop alpha
    # Reverse the channel axis (RGB -> BGR). Copy into a contiguous buffer so
    # OpenCV-based consumers accept it (a reversed view has a negative stride).
    return np.ascontiguousarray(img_np[..., ::-1])



# One EasyOCR reader shared by every OCR call (loading its models is slow).
reader = easyocr.Reader(['en'])
_DEFAULT_OCR_READER = reader


def text_ocr(image, reader=None):
    """Run EasyOCR on one panel image.

    Args:
        image: the panel image, expected in RGB (PIL image or array).
            It is converted to BGR with preprocess_image before OCR.
        reader: an easyocr.Reader. Defaults to the shared module-level reader.

    Returns:
        The list returned by ``reader.readtext``. organiser_ocr reads item[0]
        (the box corners) and item[1] (the text) of each entry.
    """
    # The reader is not used as a default argument value: defaults are evaluated
    # once at definition time, which would load a second copy of the OCR models.
    if reader is None:
        reader = _DEFAULT_OCR_READER
    return reader.readtext(preprocess_image(image))


## Group the text fragments of one speech balloon / character (spatially close boxes).
def organiser_ocr(result, eps=60):
    """Cluster OCR boxes by position and join their text into one lowercase string.

    Args:
        result: list of OCR entries. entry[0] is the list of corner points of the
            box and entry[1] is the recognised text (as produced by text_ocr).
        eps: DBSCAN neighbourhood radius, in pixels, between box centres
            (default 60, the value used originally).

    Returns:
        The texts joined with spaces and lowercased. Boxes of the same cluster are
        adjacent and keep their OCR order; clusters are emitted in ascending label
        order. Returns "" (and prints a notice) when ``result`` is empty or None.
    """
    if not result:
        print("OCR result is empty.")
        return ""

    # 1. Centre of each text box = mean of its corner points.
    centres = np.array([np.mean(np.asarray(entry[0], dtype=float), axis=0) for entry in result])

    # 2. DBSCAN with min_samples=1: every box belongs to some cluster (no noise label).
    labels = DBSCAN(eps=eps, min_samples=1).fit(centres).labels_

    # np.unique yields cluster ids in sorted order; iterating a set gives no ordering guarantee.
    texts = []
    for cluster_id in np.unique(labels):
        texts.extend(result[idx][1] for idx in np.flatnonzero(labels == cluster_id))

    return ' '.join(texts).lower()


In [None]:
## Fix spelling mistakes in the OCR output.

from spellchecker import SpellChecker
spell = SpellChecker()


def correct_spelling(text):
    """Spell-correct each whitespace-separated token of ``text``.

    A token is replaced by ``spell.correction(token)``, or kept unchanged when
    no correction is proposed (None). Tokens are re-joined with single spaces.
    """
    def _fix(token):
        suggestion = spell.correction(token)
        return token if suggestion is None else suggestion

    return " ".join(map(_fix, text.split()))


## transformer les textes/images en vecteurs via clip

In [None]:
from transformers import CLIPTokenizer, CLIPTextModel
import torch

clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
clip_text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")


def ocr_text2embedding(text):
    """Encode OCR text with the CLIP text transformer into a 1-D sentence embedding.

    Returns ``pooler_output`` for the single input: the final hidden state at the
    end-of-text token. Its size is the text hidden size, 512 for
    openai/clip-vit-base-patch32. It is concatenated with the 512-d image
    features into the 1024-d panel vectors used below.

    CLIP's text encoder applies a causal attention mask. The hidden state at
    position 0 (the start-of-text token) can therefore only attend to itself and
    is the same for every input text, so indexing ``last_hidden_state[:, 0]``
    discarded all of the OCR content. CLIP pools at the end-of-text token instead.
    """
    inputs = clip_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = clip_text_encoder(**inputs)
    # pooler_output: [1, hidden] -> [hidden]
    return outputs.pooler_output.squeeze(0)


In [None]:
def image_ocr(image):
    """Panel image -> (cleaned OCR text, CLIP embedding of that text).

    Pipeline: text_ocr (EasyOCR on the BGR-converted image) -> organiser_ocr
    (spatial grouping + lowercasing) -> correct_spelling -> ocr_text2embedding.

    Args:
        image: the panel image in RGB. It must NOT be pre-converted.
    """
    # text_ocr already converts the image to BGR. Converting here as well
    # swapped the channels back to RGB before OCR.
    result = text_ocr(image)  # shared default reader
    text = organiser_ocr(result)
    text = correct_spelling(text)
    text_embeddings = ocr_text2embedding(text)
    return text, text_embeddings

In [None]:
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch

CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
clip_processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)


@torch.no_grad()
def image2embedding(image):
    """Return the CLIP image embedding of ``image`` as a 1-D tensor.

    The image goes through ``clip_processor`` and ``clip_model.get_image_features``.
    The first row of the resulting batch is returned, shape (512,) for this
    checkpoint. Gradients are disabled.
    """
    pixel_inputs = clip_processor(images=image, return_tensors="pt")
    image_features = clip_model.get_image_features(**pixel_inputs)
    first_row = image_features[0]
    return first_row



In [None]:
def combine_embeddings(img_embed, ocr_embed, mode="concat"):
    """Fuse an image embedding and an OCR-text embedding into one vector.

    Args:
        img_embed, ocr_embed: tensors to fuse (same shape for "sum"/"average").
        mode: "concat" joins them along the last axis (img first, then ocr);
            "sum" adds them element-wise; "average" is their element-wise mean.

    Raises:
        ValueError: for any other ``mode``.
    """
    if mode == "concat":
        return torch.cat((img_embed, ocr_embed), dim=-1)
    if mode in ("sum", "average"):
        total = img_embed + ocr_embed
        return total if mode == "sum" else total / 2
    raise ValueError("Unknown combination mode")

In [None]:
def safe_panel_embedding(image, i, is_option=False):
    """Compute (ocr_text, embedding) for one panel without letting a failure stop the run.

    The embedding is ``combine_embeddings(image_embed, ocr_embed)``: in the
    default concat mode, the CLIP image embedding comes first and the OCR-text
    embedding second. This is the order the signature names. The previous call
    passed them swapped, so JSON files written before this change have the
    halves in the opposite order and should be regenerated.

    Fallbacks:
      * image embedding fails -> ("[ERROR]", 1024 float32 zeros), as before;
      * OCR raises            -> text "[ERROR]", image half kept, OCR half zeros;
      * OCR finds no text     -> text "", image half kept, OCR half zeros.
        Silent panels are common in comics, and their image information is valid.

    Args:
        image: the panel image.
        i: sample position, only used in warning messages.
        is_option: whether the panel is a candidate option (for the warning text).
    """
    panel_kind = 'option' if is_option else 'context'

    try:
        image_embed = image2embedding(image)
    except Exception as e:
        print(f"[Warning] {panel_kind} OCR/embedding failed at sample {i}: {e}")
        return "[ERROR]", np.zeros((1024,), dtype=np.float32)

    ocr_embed = None
    try:
        ocr_text, text_embed = image_ocr(image)
    except Exception as e:
        print(f"[Warning] {panel_kind} OCR/embedding failed at sample {i}: {e}")
        ocr_text = "[ERROR]"
    else:
        if ocr_text and isinstance(ocr_text, str):
            ocr_embed = text_embed
        else:
            ocr_text = ""

    if ocr_embed is None:
        # Assumes the text and image embeddings have the same width
        # (512 + 512 = 1024 for CLIP ViT-B/32).
        ocr_embed = torch.zeros_like(image_embed)

    return ocr_text, combine_embeddings(image_embed, ocr_embed)

def _embed_panels(images, i, is_option):
    """Embed every panel image; returns (texts, embeddings), both in input order."""
    texts, embeddings = [], []
    for image in images:
        ocr_text, emb = safe_panel_embedding(image, i, is_option=is_option)
        texts.append(ocr_text)
        embeddings.append(emb)
    return texts, embeddings


dataset_text = []
dataset_emb = []
for i, sample in tqdm(enumerate(dataset)):
    index = sample['index']
    solution_index = sample['solution_index']

    context_text, context_embeddings = _embed_panels(sample['context'], i, is_option=False)
    option_text, option_embeddings = _embed_panels(sample['options'], i, is_option=True)

    dataset_text.append({
        'context': context_text,
        'options': option_text,
        "index": index,
        'solution_index': solution_index,
    })
    # Embeddings (torch tensors or numpy arrays) are stored as plain lists for JSON.
    dataset_emb.append({
        'context': [emb.tolist() for emb in context_embeddings],
        'options': [emb.tolist() for emb in option_embeddings],
        "index": index,
        'solution_index': solution_index,
    })
    print(f'{i} sample embeded!')


import os

save_dir = "/content/drive/MyDrive/projet_bd/results"
os.makedirs(save_dir, exist_ok=True)

# Both paths are derived from save_dir, so the files land in the directory just created.
# Explicit UTF-8 because ensure_ascii=False writes non-ASCII OCR text verbatim.
with open(os.path.join(save_dir, "dataset_text2.json"), 'w', encoding='utf-8') as f:
    json.dump(dataset_text, f, indent=2, ensure_ascii=False)

with open(os.path.join(save_dir, "dataset_embeddings2.json"), 'w', encoding='utf-8') as f:
    json.dump(dataset_emb, f, indent=2, ensure_ascii=False)



1it [00:16, 16.26s/it]

0 sample embeded!


2it [00:26, 12.89s/it]

1 sample embeded!
OCR result is empty.


3it [00:32,  9.57s/it]

2 sample embeded!


4it [00:53, 14.06s/it]

3 sample embeded!


5it [01:11, 15.48s/it]

4 sample embeded!


6it [01:35, 18.43s/it]

5 sample embeded!


7it [01:53, 18.39s/it]

6 sample embeded!


8it [02:12, 18.40s/it]

7 sample embeded!


9it [02:26, 17.06s/it]

8 sample embeded!


10it [02:38, 15.65s/it]

9 sample embeded!


11it [02:44, 12.65s/it]

OCR result is empty.
10 sample embeded!


12it [03:08, 16.15s/it]

11 sample embeded!
OCR result is empty.


13it [03:22, 15.30s/it]

12 sample embeded!


14it [03:41, 16.55s/it]

13 sample embeded!


15it [04:04, 18.40s/it]

14 sample embeded!


16it [04:22, 18.22s/it]

15 sample embeded!


17it [04:39, 18.02s/it]

16 sample embeded!


18it [04:49, 15.54s/it]

17 sample embeded!
OCR result is empty.


19it [05:00, 14.10s/it]

18 sample embeded!


20it [05:14, 14.25s/it]

19 sample embeded!


21it [05:23, 12.59s/it]

20 sample embeded!


22it [05:29, 10.59s/it]

21 sample embeded!


23it [05:44, 11.97s/it]

22 sample embeded!


24it [05:48,  9.64s/it]

23 sample embeded!


25it [06:13, 14.20s/it]

24 sample embeded!


26it [06:24, 13.23s/it]

25 sample embeded!


27it [06:29, 10.67s/it]

26 sample embeded!


28it [06:39, 10.40s/it]

27 sample embeded!


29it [06:57, 12.83s/it]

28 sample embeded!


30it [07:10, 12.83s/it]

29 sample embeded!


31it [07:22, 12.73s/it]

30 sample embeded!


32it [07:39, 13.91s/it]

31 sample embeded!
OCR result is empty.


33it [07:55, 14.61s/it]

32 sample embeded!


34it [08:20, 17.56s/it]

33 sample embeded!


35it [08:32, 15.92s/it]

34 sample embeded!


36it [08:56, 18.30s/it]

35 sample embeded!


37it [09:13, 18.00s/it]

36 sample embeded!


38it [09:30, 17.75s/it]

37 sample embeded!


39it [09:45, 16.91s/it]

38 sample embeded!


40it [09:56, 15.19s/it]

39 sample embeded!


41it [10:15, 16.15s/it]

40 sample embeded!


42it [10:24, 14.01s/it]

41 sample embeded!
OCR result is empty.


43it [10:33, 12.69s/it]

42 sample embeded!


44it [10:52, 14.61s/it]

43 sample embeded!


45it [11:04, 13.81s/it]

44 sample embeded!


46it [11:14, 12.61s/it]

45 sample embeded!
OCR result is empty.


47it [11:28, 13.03s/it]

46 sample embeded!


48it [11:43, 13.55s/it]

47 sample embeded!


49it [11:53, 12.37s/it]

48 sample embeded!


50it [12:11, 14.05s/it]

49 sample embeded!


51it [12:21, 12.85s/it]

50 sample embeded!
OCR result is empty.


52it [12:24, 10.16s/it]

51 sample embeded!


53it [12:49, 14.42s/it]

52 sample embeded!


54it [13:09, 16.19s/it]

53 sample embeded!


55it [13:26, 16.34s/it]

54 sample embeded!


56it [13:42, 16.27s/it]

55 sample embeded!


57it [13:52, 14.44s/it]

56 sample embeded!


58it [14:11, 15.71s/it]

57 sample embeded!


59it [14:20, 13.63s/it]

58 sample embeded!


60it [14:23, 10.51s/it]

59 sample embeded!


61it [14:38, 11.88s/it]

60 sample embeded!


62it [14:55, 13.54s/it]

61 sample embeded!


63it [15:06, 12.59s/it]

62 sample embeded!


64it [15:14, 11.39s/it]

63 sample embeded!


65it [15:30, 12.69s/it]

64 sample embeded!


66it [15:48, 14.21s/it]

65 sample embeded!


67it [16:07, 15.71s/it]

66 sample embeded!


68it [16:21, 15.26s/it]

67 sample embeded!


69it [16:52, 20.05s/it]

68 sample embeded!


70it [17:01, 16.67s/it]

69 sample embeded!


71it [17:20, 17.44s/it]

70 sample embeded!


72it [18:11, 27.44s/it]

71 sample embeded!


73it [18:21, 22.15s/it]

72 sample embeded!


74it [18:41, 21.46s/it]

73 sample embeded!


75it [19:01, 21.04s/it]

74 sample embeded!


76it [19:17, 19.63s/it]

75 sample embeded!


77it [19:27, 16.80s/it]

76 sample embeded!


78it [19:31, 12.83s/it]

77 sample embeded!


79it [19:56, 16.47s/it]

78 sample embeded!


80it [20:08, 15.04s/it]

79 sample embeded!


81it [20:13, 11.99s/it]

80 sample embeded!
OCR result is empty.


82it [20:34, 14.73s/it]

81 sample embeded!


83it [20:49, 14.87s/it]

82 sample embeded!


84it [21:02, 14.44s/it]

83 sample embeded!


85it [21:28, 17.96s/it]

84 sample embeded!


86it [21:40, 15.93s/it]

85 sample embeded!


87it [21:49, 13.89s/it]

86 sample embeded!


88it [22:08, 15.50s/it]

87 sample embeded!


89it [22:40, 20.57s/it]

88 sample embeded!


90it [22:55, 18.89s/it]

89 sample embeded!


91it [23:17, 19.75s/it]

90 sample embeded!
OCR result is empty.


92it [23:29, 17.26s/it]

91 sample embeded!


93it [23:35, 13.92s/it]

92 sample embeded!
OCR result is empty.


94it [23:46, 13.14s/it]

93 sample embeded!


95it [23:55, 12.00s/it]

94 sample embeded!


96it [24:10, 12.71s/it]

95 sample embeded!


97it [24:27, 14.13s/it]

96 sample embeded!


98it [24:37, 12.98s/it]

97 sample embeded!


99it [24:51, 13.00s/it]

98 sample embeded!


100it [25:02, 12.57s/it]

99 sample embeded!


101it [25:11, 11.38s/it]

100 sample embeded!


102it [25:24, 11.91s/it]

101 sample embeded!


103it [25:36, 12.00s/it]

102 sample embeded!


104it [26:06, 17.35s/it]

103 sample embeded!


105it [26:12, 14.07s/it]

104 sample embeded!


106it [26:30, 15.05s/it]

105 sample embeded!


107it [27:05, 21.09s/it]

106 sample embeded!


108it [27:11, 16.61s/it]

107 sample embeded!


109it [27:21, 14.61s/it]

108 sample embeded!


110it [27:40, 15.91s/it]

109 sample embeded!


111it [27:59, 16.98s/it]

110 sample embeded!


112it [28:24, 19.17s/it]

111 sample embeded!


113it [28:45, 19.91s/it]

112 sample embeded!


114it [28:55, 16.95s/it]

113 sample embeded!


115it [29:29, 21.87s/it]

114 sample embeded!


116it [29:44, 19.91s/it]

115 sample embeded!


117it [30:03, 19.67s/it]

116 sample embeded!


118it [30:19, 18.59s/it]

117 sample embeded!


119it [30:24, 14.47s/it]

118 sample embeded!


120it [30:39, 14.60s/it]

119 sample embeded!
OCR result is empty.


121it [30:56, 15.38s/it]

120 sample embeded!


122it [31:09, 14.74s/it]

121 sample embeded!


123it [31:35, 18.00s/it]

122 sample embeded!


124it [31:51, 17.33s/it]

123 sample embeded!


125it [32:02, 15.59s/it]

124 sample embeded!


126it [32:23, 17.07s/it]

125 sample embeded!


127it [32:37, 16.18s/it]

126 sample embeded!


128it [32:45, 13.81s/it]

127 sample embeded!


129it [33:03, 14.98s/it]

128 sample embeded!


130it [33:20, 15.74s/it]

129 sample embeded!
OCR result is empty.


131it [33:24, 12.24s/it]

130 sample embeded!


132it [33:30, 10.29s/it]

131 sample embeded!
OCR result is empty.


133it [33:38,  9.53s/it]

132 sample embeded!


134it [33:57, 12.36s/it]

133 sample embeded!


135it [34:14, 13.68s/it]

134 sample embeded!


136it [34:29, 14.23s/it]

135 sample embeded!


137it [34:52, 16.66s/it]

136 sample embeded!


138it [35:08, 16.49s/it]

137 sample embeded!


139it [35:38, 20.76s/it]

138 sample embeded!


140it [36:02, 21.59s/it]

139 sample embeded!


141it [36:09, 17.20s/it]

140 sample embeded!


142it [36:31, 18.55s/it]

141 sample embeded!


143it [36:41, 16.09s/it]

142 sample embeded!


144it [36:45, 12.62s/it]

143 sample embeded!


145it [37:01, 13.48s/it]

144 sample embeded!


146it [37:21, 15.39s/it]

145 sample embeded!
OCR result is empty.


147it [37:39, 16.40s/it]

146 sample embeded!


148it [37:58, 17.01s/it]

147 sample embeded!


149it [38:14, 16.60s/it]

148 sample embeded!


150it [38:24, 14.87s/it]

149 sample embeded!


151it [38:48, 17.58s/it]

150 sample embeded!


152it [38:55, 14.28s/it]

151 sample embeded!


153it [38:58, 10.93s/it]

152 sample embeded!


154it [39:13, 12.02s/it]

153 sample embeded!


155it [39:34, 14.72s/it]

154 sample embeded!


156it [39:51, 15.50s/it]

155 sample embeded!


157it [40:05, 15.21s/it]

156 sample embeded!


158it [40:29, 17.57s/it]

157 sample embeded!


159it [40:40, 15.76s/it]

158 sample embeded!


160it [41:02, 17.63s/it]

159 sample embeded!


161it [41:22, 18.22s/it]

160 sample embeded!
OCR result is empty.
OCR result is empty.


162it [41:45, 19.65s/it]

161 sample embeded!


163it [41:59, 18.10s/it]

162 sample embeded!


164it [42:22, 19.53s/it]

163 sample embeded!


165it [42:55, 23.48s/it]

164 sample embeded!


166it [43:10, 21.16s/it]

165 sample embeded!


167it [43:19, 17.24s/it]

166 sample embeded!


168it [43:37, 17.55s/it]

167 sample embeded!


169it [43:48, 15.53s/it]

168 sample embeded!


170it [44:17, 19.69s/it]

169 sample embeded!
OCR result is empty.


171it [44:24, 15.78s/it]

170 sample embeded!


172it [44:43, 16.85s/it]

171 sample embeded!


173it [44:57, 15.94s/it]

172 sample embeded!


174it [45:06, 13.90s/it]

173 sample embeded!


175it [45:26, 15.78s/it]

174 sample embeded!


176it [45:46, 16.98s/it]

175 sample embeded!


177it [45:54, 14.33s/it]

176 sample embeded!


178it [46:10, 14.94s/it]

177 sample embeded!


179it [46:36, 18.07s/it]

178 sample embeded!


180it [46:54, 17.97s/it]

179 sample embeded!


181it [47:10, 17.59s/it]

180 sample embeded!


182it [47:24, 16.39s/it]

181 sample embeded!


183it [47:32, 14.03s/it]

182 sample embeded!


184it [47:50, 15.11s/it]

183 sample embeded!


185it [48:05, 15.08s/it]

184 sample embeded!


186it [48:18, 14.50s/it]

185 sample embeded!


187it [48:21, 11.02s/it]

186 sample embeded!
OCR result is empty.


188it [48:34, 11.49s/it]

187 sample embeded!


189it [48:47, 12.13s/it]

188 sample embeded!


190it [48:59, 12.05s/it]

189 sample embeded!


191it [49:19, 14.44s/it]

190 sample embeded!


192it [49:32, 14.08s/it]

191 sample embeded!


193it [49:49, 14.71s/it]

192 sample embeded!


194it [50:12, 17.44s/it]

193 sample embeded!


195it [50:33, 18.31s/it]

194 sample embeded!


196it [50:48, 17.45s/it]

195 sample embeded!


197it [51:02, 16.43s/it]

196 sample embeded!


198it [51:11, 14.10s/it]

197 sample embeded!


199it [51:23, 13.59s/it]

198 sample embeded!


200it [51:46, 16.44s/it]

199 sample embeded!
OCR result is empty.
OCR result is empty.


201it [51:53, 13.50s/it]

200 sample embeded!


202it [52:09, 14.23s/it]

201 sample embeded!


203it [52:26, 15.07s/it]

202 sample embeded!
OCR result is empty.


204it [52:43, 15.58s/it]

203 sample embeded!


205it [52:56, 14.94s/it]

204 sample embeded!


206it [53:13, 15.65s/it]

205 sample embeded!


207it [53:20, 12.88s/it]

206 sample embeded!


208it [53:40, 14.99s/it]

207 sample embeded!


209it [53:51, 13.96s/it]

208 sample embeded!


210it [53:59, 12.06s/it]

209 sample embeded!
OCR result is empty.


211it [54:02,  9.41s/it]

210 sample embeded!


212it [54:21, 12.15s/it]

211 sample embeded!


213it [54:26, 10.05s/it]

212 sample embeded!


214it [55:04, 18.42s/it]

213 sample embeded!


215it [55:09, 14.44s/it]

214 sample embeded!


216it [55:24, 14.65s/it]

215 sample embeded!


217it [55:45, 16.43s/it]

216 sample embeded!


218it [55:53, 14.09s/it]

217 sample embeded!


219it [56:03, 12.61s/it]

218 sample embeded!
OCR result is empty.


220it [56:06,  9.82s/it]

OCR result is empty.
219 sample embeded!


221it [56:18, 10.50s/it]

220 sample embeded!


222it [56:26,  9.67s/it]

221 sample embeded!


223it [56:43, 11.86s/it]

222 sample embeded!


224it [56:59, 13.11s/it]

223 sample embeded!


225it [57:20, 15.52s/it]

224 sample embeded!


226it [57:42, 17.45s/it]

225 sample embeded!


227it [57:54, 15.87s/it]

226 sample embeded!


228it [58:01, 13.23s/it]

227 sample embeded!


229it [58:07, 11.07s/it]

228 sample embeded!


230it [58:16, 10.32s/it]

229 sample embeded!


231it [58:34, 12.74s/it]

230 sample embeded!


232it [58:44, 11.82s/it]

231 sample embeded!


233it [59:00, 13.25s/it]

232 sample embeded!


234it [59:17, 14.45s/it]

233 sample embeded!


235it [59:37, 16.00s/it]

234 sample embeded!


236it [59:41, 12.46s/it]

235 sample embeded!


237it [59:51, 11.67s/it]

236 sample embeded!


238it [1:00:10, 13.68s/it]

237 sample embeded!


239it [1:00:20, 12.86s/it]

238 sample embeded!


240it [1:00:45, 16.49s/it]

239 sample embeded!


241it [1:00:54, 13.98s/it]

240 sample embeded!


242it [1:01:07, 13.84s/it]

241 sample embeded!


243it [1:01:22, 14.16s/it]

242 sample embeded!


244it [1:01:50, 18.24s/it]

243 sample embeded!


245it [1:02:15, 20.32s/it]

244 sample embeded!


246it [1:02:32, 19.21s/it]

245 sample embeded!


247it [1:02:46, 17.90s/it]

246 sample embeded!


248it [1:03:02, 17.14s/it]

247 sample embeded!


249it [1:03:15, 15.94s/it]

248 sample embeded!


250it [1:03:42, 19.38s/it]

249 sample embeded!


251it [1:04:02, 19.43s/it]

250 sample embeded!
OCR result is empty.


252it [1:04:17, 18.23s/it]

251 sample embeded!


253it [1:04:36, 18.40s/it]

252 sample embeded!


254it [1:05:03, 21.11s/it]

253 sample embeded!


255it [1:05:15, 18.30s/it]

254 sample embeded!


256it [1:05:37, 19.44s/it]

255 sample embeded!


257it [1:05:57, 19.48s/it]

256 sample embeded!


258it [1:06:16, 19.32s/it]

257 sample embeded!


259it [1:06:41, 21.18s/it]

258 sample embeded!


260it [1:07:11, 23.86s/it]

259 sample embeded!


261it [1:07:15, 17.67s/it]

260 sample embeded!


262it [1:07:20, 15.42s/it]

261 sample embeded!





In [None]:
import torch
import json
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

class ComicPAPDataset(Dataset):
    """Map-style dataset over the precomputed panel embeddings stored as JSON.

    The JSON file holds a list of records. Each record has 'context' and
    'options' (lists of embedding vectors) and 'solution_index' (position of the
    correct option). An item is a dict with float32 'context' [num_ctx, D],
    float32 'options' [num_opt, D] and an int64 scalar 'label'.
    """

    def __init__(self, json_path):
        with open(json_path, 'r') as fp:
            self.samples = json.load(fp)

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _to_matrix(vectors):
        # One float32 row per panel embedding.
        return torch.stack([torch.tensor(row, dtype=torch.float32) for row in vectors])

    def __getitem__(self, idx):
        record = self.samples[idx]
        return {
            'context': self._to_matrix(record['context']),
            'options': self._to_matrix(record['options']),
            'label': torch.tensor(record['solution_index'], dtype=torch.long),
        }


In [None]:
def pad_collate_fn(batch):
    """Collate samples with a variable number of options into padded batch tensors.

    Args:
        batch: list of dicts with 'context' [num_ctx, D], 'options' [num_opt, D]
            and 'label' (int or 0-d tensor). All samples must share num_ctx,
            because contexts are stacked as-is.

    Returns:
        dict with
          'context'     [B, num_ctx, D]
          'options'     [B, max_opt, D], zero-padded after each sample's real options
          'label'       [B] int64
          'option_mask' [B, max_opt] int64, 1 for real options and 0 for padding
    """
    max_opt_len = max(sample['options'].shape[0] for sample in batch)

    context_batch, options_batch, label_batch, option_mask = [], [], [], []
    for sample in batch:
        context_batch.append(sample['context'])

        options = sample['options']
        num_opt, emb_dim = options.shape
        missing = max_opt_len - num_opt
        if missing > 0:
            # Padding matches the options' own width, dtype and device
            # instead of a hard-coded float32 width of 1024.
            options = torch.cat([options, options.new_zeros((missing, emb_dim))], dim=0)
        options_batch.append(options)

        label_batch.append(int(sample['label']))
        option_mask.append(torch.tensor([1] * num_opt + [0] * missing))

    return {
        'context': torch.stack(context_batch),
        'options': torch.stack(options_batch),
        'label': torch.tensor(label_batch, dtype=torch.long),
        'option_mask': torch.stack(option_mask),
    }


In [None]:
import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

SEED = 42

dataset = ComicPAPDataset(json_path="/content/drive/MyDrive/projet_bd/results/dataset_embeddings2.json")
train_dataset, val_dataset = train_test_split(dataset, test_size=0.2, random_state=SEED)
# Reshuffle training batches every epoch, seeded so runs are reproducible.
# Validation order does not affect the metrics.
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
    collate_fn=pad_collate_fn,
)
val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=pad_collate_fn)

In [None]:
# Sanity check: print the tensor shapes of the first training batch only.
for batch in train_loader:
    for title, key in (("Context shape:", 'context'),
                       ("Options shape:", 'options'),
                       ("Option mask shape:", 'option_mask'),
                       ("Label shape:", 'label')):
        print(title, batch[key].shape)
    break  # one batch is enough


Context shape: torch.Size([8, 4, 1024])
Options shape: torch.Size([8, 5, 1024])
Option mask shape: torch.Size([8, 5])
Label shape: torch.Size([8])


# modèle MLP

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PanelMLPClassifier(nn.Module):
    """Score each candidate panel against the mean-pooled context with a shared MLP.

    Args:
        embedding_dim: width D of each panel embedding.
        hidden_dim: width of the first hidden layer (the second is hidden_dim // 2).
    """

    def __init__(self, embedding_dim=1024, hidden_dim=512):
        super().__init__()
        # The same MLP scores every (context summary, option) pair.
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, context, options, option_mask):
        """context [B, C, D], options [B, O, D], option_mask [B, O] (0 = padding) -> logits [B, O]."""
        context_vec = context.mean(dim=1)                                        # [B, D]
        _, num_options, _ = options.shape

        context_expanded = context_vec.unsqueeze(1).expand(-1, num_options, -1)  # [B, O, D]
        pairs = torch.cat([context_expanded, options], dim=-1)                   # [B, O, 2D]

        logits = self.mlp(pairs).squeeze(-1)                                     # [B, O]

        # Push padded options to the lowest finite value of the logits' dtype so
        # softmax/argmax ignore them. A literal -1e9 overflows in float16 (e.g. under autocast).
        return logits.masked_fill(option_mask == 0, torch.finfo(logits.dtype).min)


In [None]:
import torch.optim as optim

# Pick the device first so the model and every batch use the same one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(42)  # reproducible weight initialisation
model = PanelMLPClassifier().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: module called as model(context, options, option_mask) -> logits [B, O].
        dataloader: iterable of batches with 'context', 'options', 'option_mask'
            and 'label' tensors.
        optimizer: optimiser over the model's parameters.
        criterion: loss taking (logits, labels). The epoch average assumes it
            averages over the batch (reduction='mean', CrossEntropyLoss's default).
        device: device the batch tensors are moved to.

    Returns:
        (avg_loss, accuracy), both averaged per sample over the epoch.
        (nan, nan) when the loader yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for batch in dataloader:
        context = batch['context'].to(device)          # [B, C, D]
        options = batch['options'].to(device)          # [B, O, D]
        option_mask = batch['option_mask'].to(device)  # [B, O]
        labels = batch['label'].to(device)             # [B]

        optimizer.zero_grad()
        logits = model(context, options, option_mask)  # [B, O]
        # Loss between the option scores and the index of the correct option.
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Weight by batch size so a short final batch does not skew the epoch average.
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += batch_size

    if total == 0:
        return float('nan'), float('nan')
    return total_loss / total, correct / total



In [None]:
num_epochs = 50
for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)

    # Held-out accuracy: training accuracy alone cannot reveal overfitting.
    model.eval()
    val_correct, val_total = 0, 0
    with torch.no_grad():
        for val_batch in val_loader:
            val_logits = model(
                val_batch['context'].to(device),
                val_batch['options'].to(device),
                val_batch['option_mask'].to(device),
            )
            val_correct += (val_logits.argmax(dim=1).cpu() == val_batch['label']).sum().item()
            val_total += val_batch['label'].size(0)
    val_acc = val_correct / val_total if val_total else float('nan')

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

Epoch 1/50
Train Loss: 1.6031 | Train Acc: 0.2105
Epoch 2/50
Train Loss: 1.5899 | Train Acc: 0.3923
Epoch 3/50
Train Loss: 1.5771 | Train Acc: 0.4354
Epoch 4/50
Train Loss: 1.5575 | Train Acc: 0.4498
Epoch 5/50
Train Loss: 1.5285 | Train Acc: 0.4641
Epoch 6/50
Train Loss: 1.4876 | Train Acc: 0.4689
Epoch 7/50
Train Loss: 1.4349 | Train Acc: 0.4737
Epoch 8/50
Train Loss: 1.3746 | Train Acc: 0.4928
Epoch 9/50
Train Loss: 1.3108 | Train Acc: 0.4880
Epoch 10/50
Train Loss: 1.2451 | Train Acc: 0.5215
Epoch 11/50
Train Loss: 1.1775 | Train Acc: 0.5263
Epoch 12/50
Train Loss: 1.1137 | Train Acc: 0.5646
Epoch 13/50
Train Loss: 1.0509 | Train Acc: 0.5837
Epoch 14/50
Train Loss: 0.9948 | Train Acc: 0.6029
Epoch 15/50
Train Loss: 0.9385 | Train Acc: 0.6411
Epoch 16/50
Train Loss: 0.8896 | Train Acc: 0.6890
Epoch 17/50
Train Loss: 0.8414 | Train Acc: 0.7081
Epoch 18/50
Train Loss: 0.7986 | Train Acc: 0.7225
Epoch 19/50
Train Loss: 0.7590 | Train Acc: 0.7321
Epoch 20/50
Train Loss: 0.7361 | Train A

In [None]:
## ATTENTION + MLP:
class AttnMLPClassifier(nn.Module):
    """Score candidate panels against an attention-pooled summary of the context.

    A small scorer (Linear-Tanh-Linear) gives each context panel a weight. The
    context panels, weighted by the softmax of those scores and summed, are
    concatenated to each option and scored by a dropout-regularised MLP.

    Args:
        input_dim: width D of each panel embedding.
    """

    def __init__(self, input_dim=1024):
        super().__init__()
        self.attn = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )
        self.mlp = nn.Sequential(
            nn.Linear(input_dim * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1)
        )

    def forward(self, context, options, option_mask):
        """context [B, C, D], options [B, O, D], option_mask [B, O] (0 = padding) -> logits [B, O]."""
        attn_weights = self.attn(context).softmax(dim=1)   # [B, C, 1], sums to 1 over context panels
        context_vec = (context * attn_weights).sum(dim=1)  # [B, D]

        _, num_options, _ = options.shape
        context_exp = context_vec.unsqueeze(1).expand(-1, num_options, -1)  # [B, O, D]
        concat = torch.cat([context_exp, options], dim=-1)                  # [B, O, 2D]

        logits = self.mlp(concat).squeeze(-1)  # [B, O]
        # Lowest finite value of the dtype: a literal -1e9 overflows in float16.
        return logits.masked_fill(option_mask == 0, torch.finfo(logits.dtype).min)

import torch.optim as optim

# Pick the device first so the model and every batch use the same one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(42)  # reproducible weight initialisation
model = AttnMLPClassifier().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()

# Same training step as in the MLP section, repeated so this section can be re-run on its own.
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: module called as model(context, options, option_mask) -> logits [B, O].
        dataloader: iterable of batches with 'context', 'options', 'option_mask'
            and 'label' tensors.
        optimizer: optimiser over the model's parameters.
        criterion: loss taking (logits, labels). The epoch average assumes it
            averages over the batch (reduction='mean', CrossEntropyLoss's default).
        device: device the batch tensors are moved to.

    Returns:
        (avg_loss, accuracy), both averaged per sample over the epoch.
        (nan, nan) when the loader yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for batch in dataloader:
        context = batch['context'].to(device)          # [B, C, D]
        options = batch['options'].to(device)          # [B, O, D]
        option_mask = batch['option_mask'].to(device)  # [B, O]
        labels = batch['label'].to(device)             # [B]

        optimizer.zero_grad()
        logits = model(context, options, option_mask)  # [B, O]
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Weight by batch size so a short final batch does not skew the epoch average.
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += batch_size

    if total == 0:
        return float('nan'), float('nan')
    return total_loss / total, correct / total



In [None]:
num_epochs = 40
for epoch in range(num_epochs):
    # `loader` was never defined in this notebook (a leftover from a deleted cell).
    # Train on the training split only, so validation stays held out.
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)

    # Held-out accuracy on the validation split.
    model.eval()
    val_correct, val_total = 0, 0
    with torch.no_grad():
        for val_batch in val_loader:
            val_logits = model(
                val_batch['context'].to(device),
                val_batch['options'].to(device),
                val_batch['option_mask'].to(device),
            )
            val_correct += (val_logits.argmax(dim=1).cpu() == val_batch['label']).sum().item()
            val_total += val_batch['label'].size(0)
    val_acc = val_correct / val_total if val_total else float('nan')

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

Epoch 1/40
Train Loss: 1.6084 | Train Acc: 0.1527
Epoch 2/40
Train Loss: 1.6008 | Train Acc: 0.2176
Epoch 3/40
Train Loss: 1.5860 | Train Acc: 0.2824
Epoch 4/40
Train Loss: 1.5949 | Train Acc: 0.2176
Epoch 5/40
Train Loss: 1.5948 | Train Acc: 0.2405
Epoch 6/40
Train Loss: 1.5936 | Train Acc: 0.2252
Epoch 7/40
Train Loss: 1.5904 | Train Acc: 0.2252
Epoch 8/40
Train Loss: 1.5990 | Train Acc: 0.2710
Epoch 9/40
Train Loss: 1.5933 | Train Acc: 0.2252
Epoch 10/40
Train Loss: 1.5769 | Train Acc: 0.2901
Epoch 11/40
Train Loss: 1.5841 | Train Acc: 0.2710
Epoch 12/40
Train Loss: 1.5772 | Train Acc: 0.3015
Epoch 13/40
Train Loss: 1.5890 | Train Acc: 0.2405
Epoch 14/40
Train Loss: 1.5738 | Train Acc: 0.2710
Epoch 15/40
Train Loss: 1.5692 | Train Acc: 0.3053
Epoch 16/40
Train Loss: 1.5365 | Train Acc: 0.3435
Epoch 17/40
Train Loss: 1.5663 | Train Acc: 0.2863
Epoch 18/40
Train Loss: 1.5470 | Train Acc: 0.3092
Epoch 19/40
Train Loss: 1.5265 | Train Acc: 0.3702
Epoch 20/40
Train Loss: 1.5201 | Train A