In [2]:
from allennlp.data import DataLoader
from allennlp.data.samplers import BucketBatchSampler
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model
from allennlp.training import GradientDescentTrainer
from allennlp.data.dataset_readers import DatasetReader
from allennlp.common.params import Params
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder
from allennlp.data import Instance
from allennlp.models.archival import load_archive

# Load the pre-trained archive
archive = load_archive("spanbert_local/coref-spanbert-large-2021.03.10.tar.gz")
model = archive.model
model.train()
# Make sure to move the model to the correct device
model = model.cuda(0)  # Replace with model.cpu() if you don't want to use the GPU

error loading _jsonnet (this is expected on Windows), treating C:\Users\pc-bae-2\AppData\Local\Temp\tmpvqznjrl0\config.json as plain json


ConfigurationError: coref not in acceptable choices for dataset_reader.type: ['babi', 'conll2003', 'interleaving', 'multitask', 'multitask_shim', 'sequence_tagging', 'sharded', 'text_classification_json']. You should either use the --include-package flag to make sure the correct module is loaded, or use a fully qualified class name in your config file like {"model": "my_module.models.MyModel"} to have it imported automatically.

In [None]:
print(archive.config["dataset_reader"])

train_path = "gap-development.jsonl"  
val_path = "gap-validation.jsonl" 
# Load your training and validation data
reader = DatasetReader.from_params(archive.config["dataset_reader"])

dataset_reader.Params({'type': 'coref', 'max_sentences': 110, 'max_span_width': 30, 'token_indexers': {'tokens': {'type': 'pretrained_transformer_mismatched', 'max_length': 512, 'model_name': 'SpanBERT/spanbert-large-cased'}}})


In [None]:


# # # You can extend the vocabulary here if needed
# # vocab = model.vocab
# # vocab.extend_from_instances(train_data)
# # vocab.extend_from_instances(validation_data)

# # Create data loaders
# from allennlp.data.data_loaders import MultiProcessDataLoader 
# from allennlp.data.samplers import BucketBatchSampler
# from allennlp.data import Vocabulary



# # Make sure your model also uses the same vocabulary
# vocab = model.vocab

# train_sampler = BucketBatchSampler(batch_size=4, sorting_keys=["tokens"])
# print(val_path)
# train_data_loader = MultiProcessDataLoader(reader=reader,data_path=train_path, batch_sampler=train_sampler)
# validation_data_loader = MultiProcessDataLoader(reader=reader,data_path=val_path, batch_size=4)

# # Index the data loaders with the vocabulary
# train_data_loader.index_with(vocab)
# validation_data_loader.index_with(vocab)

# # Make sure your model also uses the same vocabulary
# model.vocab = vocab
# print(model.vocab)

In [None]:
from allennlp.data import DatasetReader, Instance, Field
from allennlp.data.fields import TextField, SpanField, ListField
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer, PretrainedTransformerMismatchedIndexer
from allennlp.data.tokenizers import Tokenizer, SpacyTokenizer
from typing import Iterator, List, Dict, Any
import json
from allennlp.data.dataset_readers.dataset_utils import enumerate_spans

class CustomCorefReader(DatasetReader):
    def __init__(self, 
                 tokenizer: Tokenizer = None,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 **kwargs: Any):
        super().__init__(**kwargs)
        self.tokenizer = tokenizer or SpacyTokenizer()
        # self.token_indexers = token_indexers or {"tokens": PretrainedTransformerMismatchedIndexer(model_name="SpanBERT/spanbert-coreference",max_length=512)}
        self.token_indexers ={ "tokens": PretrainedTransformerMismatchedIndexer(model_name="SpanBERT/spanbert-large-cased",max_length=512)}

    def _read(self, file_path: str) -> Iterator[Instance]:
        with open(file_path, 'r') as f:
            for line in f:
                example = json.loads(line)
                text = example['text']
                clusters = example['clusters']

                tokens_with_offsets = self.tokenizer.tokenize(text)
                tokens = [token.text for token in tokens_with_offsets]
                token_offsets = [(token.idx, token.idx + len(token.text)) for token in tokens_with_offsets]

                text_field = TextField(tokens_with_offsets, self.token_indexers)
                # text_field = [tokens_with_offsets]
                # Convert character-based spans to token-based spans
                token_based_clusters = []
                for cluster in clusters:
                    token_based_cluster = []
                    for char_start, char_end in cluster:
                        try:
                            token_start = next(i for i, (t_start, t_end) in enumerate(token_offsets) if t_start <= char_start < t_end)
                        except StopIteration:
                            print(f"StopIteration for token_start: char_start={char_start}, token_offsets={token_offsets}")
                            raise

                        try:
                            token_end = next(i for i, (t_start, t_end) in enumerate(token_offsets) if t_start < char_end <= t_end)
                        except StopIteration as e:
                            print(f"StopIteration for token_end: char_end={char_end}, token_offsets={token_offsets}")
                            # Assuming `text` contains the actual text content.
                            print(f"Text around char_end: {text[char_end-10:char_end+10]}")

                            continue

                        token_based_cluster.append((token_start, token_end))
                    token_based_clusters.append(token_based_cluster)

                span_fields: List[Field] = []
                for token_based_cluster in token_based_clusters:
                    for start, end in token_based_cluster:
                        span_fields.append(SpanField(start, end, text_field))

                span_list_field = ListField(span_fields)
            
                fields: Dict[str, Field] = {'text': text_field, 'spans': span_list_field}
                yield Instance(fields)

reader = CustomCorefReader()


In [None]:
# Read the training and validation data
train_data = list(reader.read(train_path))

validation_data = list(reader.read(val_path))
print(len(train_data))
# Build a vocabulary from the training data
# vocab = Vocabulary.from_instances(train_data)
vocab = model.vocab
# Create data loaders
from allennlp.data.data_loaders import MultiProcessDataLoader
from allennlp.data.samplers import BucketBatchSampler

train_sampler = BucketBatchSampler(batch_size=4, sorting_keys=["text"])

train_data_loader = MultiProcessDataLoader(reader=reader, data_path=train_path, batch_sampler=train_sampler)
validation_data_loader = MultiProcessDataLoader(reader=reader, data_path=val_path, batch_size=4)

# Index the data loaders with the vocabulary
# train_data_loader.index_with(vocab)
# validation_data_loader.index_with(vocab)

# Make sure your model also uses the same vocabulary
# model.vocab = vocab

# Debugging: Print the number of instances in the data
print("Number of instances in train_data:", len(train_data))
print("Number of instances in validation_data:", len(validation_data))


StopIteration for token_end: char_end=323, token_offsets=[(0, 2), (3, 12), (13, 15), (16, 24), (25, 31), (32, 34), (35, 42), (43, 48), (48, 49), (50, 51), (51, 52), (52, 55), (56, 61), (62, 67), (68, 70), (71, 73), (74, 77), (78, 85), (86, 92), (93, 97), (98, 104), (105, 107), (108, 114), (114, 116), (117, 127), (128, 134), (135, 143), (144, 146), (147, 150), (151, 157), (158, 166), (167, 169), (170, 179), (180, 188), (189, 191), (192, 195), (196, 201), (202, 204), (205, 206), (207, 211), (212, 216), (216, 217), (218, 224), (225, 235), (236, 243), (244, 251), (252, 254), (255, 263), (264, 269), (269, 270), (271, 279), (280, 283), (284, 295), (295, 296), (297, 301), (302, 304), (305, 308), (309, 316), (316, 317), (318, 322), (323, 324), (325, 329), (330, 333), (334, 343), (344, 347), (348, 361), (362, 366), (367, 370), (371, 380), (381, 385), (386, 388), (389, 400), (401, 407), (407, 409), (410, 414), (415, 417), (418, 420), (421, 426), (427, 430), (431, 437), (438, 440), (441, 444), (4

loading instances: 0it [00:00, ?it/s]

StopIteration for token_end: char_end=323, token_offsets=[(0, 2), (3, 12), (13, 15), (16, 24), (25, 31), (32, 34), (35, 42), (43, 48), (48, 49), (50, 51), (51, 52), (52, 55), (56, 61), (62, 67), (68, 70), (71, 73), (74, 77), (78, 85), (86, 92), (93, 97), (98, 104), (105, 107), (108, 114), (114, 116), (117, 127), (128, 134), (135, 143), (144, 146), (147, 150), (151, 157), (158, 166), (167, 169), (170, 179), (180, 188), (189, 191), (192, 195), (196, 201), (202, 204), (205, 206), (207, 211), (212, 216), (216, 217), (218, 224), (225, 235), (236, 243), (244, 251), (252, 254), (255, 263), (264, 269), (269, 270), (271, 279), (280, 283), (284, 295), (295, 296), (297, 301), (302, 304), (305, 308), (309, 316), (316, 317), (318, 322), (323, 324), (325, 329), (330, 333), (334, 343), (344, 347), (348, 361), (362, 366), (367, 370), (371, 380), (381, 385), (386, 388), (389, 400), (401, 407), (407, 409), (410, 414), (415, 417), (418, 420), (421, 426), (427, 430), (431, 437), (438, 440), (441, 444), (4

loading instances: 0it [00:00, ?it/s]

StopIteration for token_end: char_end=285, token_offsets=[(0, 8), (9, 12), (13, 14), (15, 28), (29, 35), (36, 38), (39, 44), (45, 51), (52, 55), (56, 58), (59, 63), (64, 69), (70, 73), (74, 77), (78, 82), (83, 85), (86, 89), (90, 93), (94, 101), (101, 102), (103, 106), (107, 111), (112, 116), (116, 118), (119, 124), (125, 133), (134, 141), (141, 142), (143, 145), (146, 153), (154, 160), (161, 163), (164, 173), (174, 176), (177, 180), (181, 185), (185, 186), (187, 192), (193, 202), (203, 207), (208, 210), (211, 216), (216, 217), (218, 225), (226, 235), (235, 236), (237, 239), (240, 243), (244, 247), (248, 257), (258, 260), (261, 266), (266, 267), (268, 273), (274, 279), (280, 284), (285, 287), (288, 295), (296, 298), (299, 304), (305, 313), (313, 314), (315, 319), (320, 323), (324, 335), (336, 338), (339, 347), (348, 350), (351, 357), (358, 364), (364, 365), (366, 379), (380, 382), (383, 390), (391, 394), (395, 397), (398, 404), (404, 405), (406, 409), (410, 420), (421, 423), (424, 429)

In [None]:
train_data_loader.index_with(vocab)
validation_data_loader.index_with(vocab) 

In [None]:
print(train_data)

[<allennlp.data.instance.Instance object at 0x000002BD4EABEFE0>, <allennlp.data.instance.Instance object at 0x000002BD4EABEDA0>, <allennlp.data.instance.Instance object at 0x000002BD4EABE8F0>, <allennlp.data.instance.Instance object at 0x000002BD56E56530>, <allennlp.data.instance.Instance object at 0x000002BD56E550C0>, <allennlp.data.instance.Instance object at 0x000002BD56E567A0>, <allennlp.data.instance.Instance object at 0x000002BD56E55240>, <allennlp.data.instance.Instance object at 0x000002BD56E54AC0>, <allennlp.data.instance.Instance object at 0x000002BD56E55C30>, <allennlp.data.instance.Instance object at 0x000002BD56E56A40>, <allennlp.data.instance.Instance object at 0x000002BD56E54910>, <allennlp.data.instance.Instance object at 0x000002BD56E56830>, <allennlp.data.instance.Instance object at 0x000002BD56E55D80>, <allennlp.data.instance.Instance object at 0x000002BD56E56B90>, <allennlp.data.instance.Instance object at 0x000002BD56E56D40>, <allennlp.data.instance.Instance object

In [None]:
from torch.optim import AdamW

# Define the AdamW optimizer
optimizer = AdamW(model.parameters(), lr=0.001) 
# Create the trainer and train 
trainer = GradientDescentTrainer(
    model=model,
    optimizer=optimizer,
    data_loader=train_data_loader,
    validation_data_loader=validation_data_loader,
    num_epochs=3,
    cuda_device=0  # set to -1 to use CPU
)

trainer.train()

You provided a validation dataset but patience was set to None, meaning that early stopping is disabled


  0%|          | 0/500 [00:00<?, ?it/s]

RuntimeError: The model you are trying to optimize does not contain a 'loss' key in the output of model.forward(inputs).